diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/execution/Environment.java b/flink-runtime/src/main/java/org/apache/flink/runtime/execution/Environment.java index c45a941d63af8..fc2d2c88a66a9 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/execution/Environment.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/execution/Environment.java @@ -23,8 +23,8 @@ import org.apache.flink.core.fs.Path; import org.apache.flink.runtime.broadcast.BroadcastVariableManager; import org.apache.flink.runtime.io.disk.iomanager.IOManager; -import org.apache.flink.runtime.io.network.api.reader.BufferReader; import org.apache.flink.runtime.io.network.api.writer.BufferWriter; +import org.apache.flink.runtime.io.network.partition.consumer.InputGate; import org.apache.flink.runtime.jobgraph.JobGraph; import org.apache.flink.runtime.jobgraph.JobID; import org.apache.flink.runtime.jobgraph.JobVertexID; @@ -50,10 +50,10 @@ public interface Environment { * @return the ID of the job from the original job graph */ JobID getJobID(); - + /** * Gets the ID of the jobVertex that this task corresponds to. - * + * * @return The JobVertexID of this task. */ JobVertexID getJobVertexId(); @@ -130,18 +130,12 @@ public interface Environment { BroadcastVariableManager getBroadcastVariableManager(); - // ------------------------------------------------------------------------ - // Runtime result writers and readers - // ------------------------------------------------------------------------ - // The environment sets up buffer-oriented writers and readers, which the - // user can use to produce and consume results. - // ------------------------------------------------------------------------ - BufferWriter getWriter(int index); BufferWriter[] getAllWriters(); - BufferReader getReader(int index); + InputGate getInputGate(int index); + + InputGate[] getAllInputGates(); - BufferReader[] getAllReaders(); } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/execution/RuntimeEnvironment.java b/flink-runtime/src/main/java/org/apache/flink/runtime/execution/RuntimeEnvironment.java index 85c0e8e270d3f..4fa490757feb3 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/execution/RuntimeEnvironment.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/execution/RuntimeEnvironment.java @@ -27,9 +27,10 @@ import org.apache.flink.runtime.deployment.TaskDeploymentDescriptor; import org.apache.flink.runtime.io.disk.iomanager.IOManager; import org.apache.flink.runtime.io.network.NetworkEnvironment; -import org.apache.flink.runtime.io.network.api.reader.BufferReader; import org.apache.flink.runtime.io.network.api.writer.BufferWriter; import org.apache.flink.runtime.io.network.partition.IntermediateResultPartition; +import org.apache.flink.runtime.io.network.partition.consumer.InputGate; +import org.apache.flink.runtime.io.network.partition.consumer.SingleInputGate; import org.apache.flink.runtime.jobgraph.IntermediateDataSetID; import org.apache.flink.runtime.jobgraph.JobID; import org.apache.flink.runtime.jobgraph.JobVertexID; @@ -95,9 +96,9 @@ public class RuntimeEnvironment implements Environment, Runnable { private final BufferWriter[] writers; - private final BufferReader[] readers; + private final SingleInputGate[] inputGates; - private final Map readersById = new HashMap(); + private final Map inputGatesById = new HashMap(); public RuntimeEnvironment( ActorRef jobManager, Task owner, TaskDeploymentDescriptor tdd, ClassLoader userCodeClassLoader, @@ -128,15 +129,18 @@ public RuntimeEnvironment( // Consumed intermediate result partitions final List consumedPartitions = tdd.getConsumedPartitions(); - this.readers = new BufferReader[consumedPartitions.size()]; + this.inputGates = new SingleInputGate[consumedPartitions.size()]; - for (int i = 0; i < readers.length; i++) { - readers[i] = BufferReader.create(this, networkEnvironment, consumedPartitions.get(i)); + for (int i = 0; i < inputGates.length; i++) { + inputGates[i] = SingleInputGate.create(networkEnvironment, consumedPartitions.get(i)); - // The readers are organized by key for task updates/channel updates at runtime - readersById.put(readers[i].getConsumedResultId(), readers[i]); + // The input gates are organized by key for task updates/channel updates at runtime + inputGatesById.put(inputGates[i].getConsumedResultId(), inputGates[i]); } + this.jobConfiguration = tdd.getJobConfiguration(); + this.taskConfiguration = tdd.getTaskConfiguration(); + // ---------------------------------------------------------------- // Invokable setup // ---------------------------------------------------------------- @@ -163,9 +167,6 @@ public RuntimeEnvironment( throw new Exception("Could not instantiate the invokable class.", t); } - this.jobConfiguration = tdd.getJobConfiguration(); - this.taskConfiguration = tdd.getTaskConfiguration(); - this.invokable.setEnvironment(this); this.invokable.registerInputOutput(); } @@ -361,23 +362,23 @@ public BufferWriter[] getAllWriters() { } @Override - public BufferReader getReader(int index) { - checkElementIndex(index, readers.length, "Illegal environment reader request."); + public InputGate getInputGate(int index) { + checkElementIndex(index, inputGates.length); - return readers[index]; + return inputGates[index]; } @Override - public BufferReader[] getAllReaders() { - return readers; + public SingleInputGate[] getAllInputGates() { + return inputGates; } public IntermediateResultPartition[] getProducedPartitions() { return producedPartitions; } - public BufferReader getReaderById(IntermediateDataSetID id) { - return readersById.get(id); + public SingleInputGate getInputGateById(IntermediateDataSetID id) { + return inputGatesById.get(id); } @Override diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/NetworkEnvironment.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/NetworkEnvironment.java index 74a448ac3127c..2a8d6d46e949f 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/NetworkEnvironment.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/NetworkEnvironment.java @@ -21,7 +21,6 @@ import akka.actor.ActorRef; import akka.util.Timeout; import org.apache.flink.runtime.executiongraph.ExecutionAttemptID; -import org.apache.flink.runtime.io.network.api.reader.BufferReader; import org.apache.flink.runtime.io.network.api.writer.BufferWriter; import org.apache.flink.runtime.io.network.buffer.BufferPool; import org.apache.flink.runtime.io.network.buffer.NetworkBufferPool; @@ -29,6 +28,7 @@ import org.apache.flink.runtime.io.network.netty.NettyConnectionManager; import org.apache.flink.runtime.io.network.partition.IntermediateResultPartition; import org.apache.flink.runtime.io.network.partition.IntermediateResultPartitionManager; +import org.apache.flink.runtime.io.network.partition.consumer.SingleInputGate; import org.apache.flink.runtime.taskmanager.NetworkEnvironmentConfiguration; import org.apache.flink.runtime.taskmanager.Task; import org.apache.flink.runtime.taskmanager.TaskManager; @@ -154,14 +154,14 @@ public void registerTask(Task task) throws IOException { } // Setup the buffer pool for each buffer reader - final BufferReader[] readers = task.getReaders(); + final SingleInputGate[] inputGates = task.getInputGates(); - for (BufferReader reader : readers) { + for (SingleInputGate gate : inputGates) { BufferPool bufferPool = null; try { - bufferPool = networkBufferPool.createBufferPool(reader.getNumberOfInputChannels(), false); - reader.setBufferPool(bufferPool); + bufferPool = networkBufferPool.createBufferPool(gate.getNumberOfInputChannels(), false); + gate.setBufferPool(bufferPool); } catch (Throwable t) { if (bufferPool != null) { @@ -191,13 +191,13 @@ public void unregisterTask(Task task) { taskEventDispatcher.unregisterWriters(executionId); - final BufferReader[] readers = task.getReaders(); + final SingleInputGate[] inputGates = task.getInputGates(); - if (readers != null) { - for (BufferReader reader : readers) { + if (inputGates != null) { + for (SingleInputGate gate : inputGates) { try { - if (reader != null) { - reader.releaseAllResources(); + if (gate != null) { + gate.releaseAllResources(); } } catch (IOException e) { diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/util/event/EventNotificationHandler.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/TaskEventHandler.java similarity index 65% rename from flink-runtime/src/main/java/org/apache/flink/runtime/util/event/EventNotificationHandler.java rename to flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/TaskEventHandler.java index a9b3a8678dfae..95fce96e71cbd 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/util/event/EventNotificationHandler.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/TaskEventHandler.java @@ -16,40 +16,42 @@ * limitations under the License. */ -package org.apache.flink.runtime.util.event; +package org.apache.flink.runtime.io.network.api; import com.google.common.collect.HashMultimap; import com.google.common.collect.Multimap; +import org.apache.flink.runtime.event.task.TaskEvent; +import org.apache.flink.runtime.util.event.EventListener; /** * The event handler manages {@link EventListener} instances and allows to * to publish events to them. */ -public class EventNotificationHandler { +public class TaskEventHandler { // Listeners for each event type - private final Multimap, EventListener> listeners = HashMultimap.create(); + private final Multimap, EventListener> listeners = HashMultimap.create(); - public void subscribe(EventListener listener, Class eventType) { + public void subscribe(EventListener listener, Class eventType) { synchronized (listeners) { listeners.put(eventType, listener); } } - public void unsubscribe(EventListener listener, Class eventType) { + public void unsubscribe(EventListener listener, Class eventType) { synchronized (listeners) { listeners.remove(eventType, listener); } } /** - * Publishes the event to all subscribed {@link EventListener} objects. + * Publishes the task event to all subscribed event listeners.. * * @param event The event to publish. */ - public void publish(T event) { + public void publish(TaskEvent event) { synchronized (listeners) { - for (EventListener listener : listeners.get((Class) event.getClass())) { + for (EventListener listener : listeners.get(event.getClass())) { listener.onEvent(event); } } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/reader/AbstractReader.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/reader/AbstractReader.java new file mode 100644 index 0000000000000..1bfca84061100 --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/reader/AbstractReader.java @@ -0,0 +1,146 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.io.network.api.reader; + +import org.apache.flink.runtime.event.task.AbstractEvent; +import org.apache.flink.runtime.event.task.TaskEvent; +import org.apache.flink.runtime.io.network.api.EndOfPartitionEvent; +import org.apache.flink.runtime.io.network.api.EndOfSuperstepEvent; +import org.apache.flink.runtime.io.network.api.TaskEventHandler; +import org.apache.flink.runtime.io.network.partition.consumer.InputGate; +import org.apache.flink.runtime.util.event.EventListener; + +import java.io.IOException; + +import static com.google.common.base.Preconditions.checkState; + +/** + * A basic reader implementation, which wraps an input gate and handles events. + */ +public abstract class AbstractReader implements ReaderBase { + + /** The input gate to read from. */ + protected final InputGate inputGate; + + /** The task event handler to manage task event subscriptions. */ + private final TaskEventHandler taskEventHandler = new TaskEventHandler(); + + /** Flag indicating whether this reader allows iteration events. */ + private boolean isIterative; + + /** + * The current number of end of superstep events (reset for each superstep). A superstep is + * finished after an end of superstep event has been received for each input channel. + */ + private int currentNumberOfEndOfSuperstepEvents; + + protected AbstractReader(InputGate inputGate) { + this.inputGate = inputGate; + } + + @Override + public boolean isFinished() { + return inputGate.isFinished(); + } + + // ------------------------------------------------------------------------ + // Events + // ------------------------------------------------------------------------ + + @Override + public void registerTaskEventListener(EventListener listener, Class eventType) { + taskEventHandler.subscribe(listener, eventType); + } + + @Override + public void sendTaskEvent(TaskEvent event) throws IOException { + inputGate.sendTaskEvent(event); + } + + /** + * Handles the event and returns whether the reader reached an end-of-stream event (either the + * end of the whole stream or the end of an superstep). + */ + protected boolean handleEvent(AbstractEvent event) throws IOException { + final Class eventType = event.getClass(); + + try { + // ------------------------------------------------------------ + // Runtime events + // ------------------------------------------------------------ + + // This event is also checked at the (single) input gate to release the respective + // channel, at which it was received. + if (eventType == EndOfPartitionEvent.class) { + return true; + } + else if (eventType == EndOfSuperstepEvent.class) { + return incrementEndOfSuperstepEventAndCheck(); + } + + // ------------------------------------------------------------ + // Task events (user) + // ------------------------------------------------------------ + else if (event instanceof TaskEvent) { + taskEventHandler.publish((TaskEvent) event); + + return false; + } + else { + throw new IllegalStateException("Received unexpected event of type " + eventType + " at reader."); + } + } + catch (Throwable t) { + throw new IOException("Error while handling event of type " + eventType + ": " + t.getMessage(), t); + } + } + + // ------------------------------------------------------------------------ + // Iterations + // ------------------------------------------------------------------------ + + @Override + public void setIterativeReader() { + isIterative = true; + } + + @Override + public void startNextSuperstep() { + checkState(isIterative, "Tried to start next superstep in a non-iterative reader."); + checkState(currentNumberOfEndOfSuperstepEvents == inputGate.getNumberOfInputChannels(), "Tried to start next superstep before reaching end of previous superstep."); + + currentNumberOfEndOfSuperstepEvents = 0; + } + + @Override + public boolean hasReachedEndOfSuperstep() { + if (isIterative) { + return currentNumberOfEndOfSuperstepEvents == inputGate.getNumberOfInputChannels(); + } + + return false; + } + + private boolean incrementEndOfSuperstepEventAndCheck() { + checkState(isIterative, "Tried to increment superstep count in a non-iterative reader."); + checkState(currentNumberOfEndOfSuperstepEvents + 1 <= inputGate.getNumberOfInputChannels(), "Received too many (" + currentNumberOfEndOfSuperstepEvents + ") end of superstep events."); + + return ++currentNumberOfEndOfSuperstepEvents == inputGate.getNumberOfInputChannels(); + } +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/reader/AbstractRecordReader.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/reader/AbstractRecordReader.java index 2ee3256320b41..e70b6eeeaec82 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/reader/AbstractRecordReader.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/reader/AbstractRecordReader.java @@ -19,26 +19,23 @@ package org.apache.flink.runtime.io.network.api.reader; import org.apache.flink.core.io.IOReadableWritable; -import org.apache.flink.runtime.event.task.TaskEvent; import org.apache.flink.runtime.io.network.api.serialization.RecordDeserializer; import org.apache.flink.runtime.io.network.api.serialization.RecordDeserializer.DeserializationResult; import org.apache.flink.runtime.io.network.buffer.Buffer; +import org.apache.flink.runtime.io.network.partition.consumer.BufferOrEvent; +import org.apache.flink.runtime.io.network.partition.consumer.InputGate; import org.apache.flink.runtime.io.network.serialization.SpillingAdaptiveSpanningRecordDeserializer; -import org.apache.flink.runtime.util.event.EventListener; import java.io.IOException; /** - * A record-oriented runtime result reader, which wraps a {@link BufferReaderBase}. + * A record-oriented reader. *

- * This abstract base class is used by both the mutable and immutable record - * reader. + * This abstract base class is used by both the mutable and immutable record readers. * * @param The type of the record that can be read with this record reader. */ -abstract class AbstractRecordReader implements ReaderBase { - - private final BufferReaderBase reader; +abstract class AbstractRecordReader extends AbstractReader implements ReaderBase { private final RecordDeserializer[] recordDeserializers; @@ -46,11 +43,11 @@ abstract class AbstractRecordReader implements Rea private boolean isFinished; - protected AbstractRecordReader(BufferReaderBase reader) { - this.reader = reader; + protected AbstractRecordReader(InputGate inputGate) { + super(inputGate); // Initialize one deserializer per input channel - this.recordDeserializers = new SpillingAdaptiveSpanningRecordDeserializer[reader.getNumberOfInputChannels()]; + this.recordDeserializers = new SpillingAdaptiveSpanningRecordDeserializer[inputGate.getNumberOfInputChannels()]; for (int i = 0; i < recordDeserializers.length; i++) { recordDeserializers[i] = new SpillingAdaptiveSpanningRecordDeserializer(); } @@ -75,25 +72,23 @@ protected boolean getNextRecord(T target) throws IOException, InterruptedExcepti } } - final Buffer nextBuffer = reader.getNextBufferBlocking(); - final int channelIndex = reader.getChannelIndexOfLastBuffer(); + final BufferOrEvent bufferOrEvent = inputGate.getNextBufferOrEvent(); - if (nextBuffer == null) { - if (reader.isFinished()) { + if (bufferOrEvent.isBuffer()) { + currentRecordDeserializer = recordDeserializers[bufferOrEvent.getChannelIndex()]; + currentRecordDeserializer.setNextBuffer(bufferOrEvent.getBuffer()); + } + else if (handleEvent(bufferOrEvent.getEvent())) { + if (inputGate.isFinished()) { isFinished = true; + return false; } - else if (reader.hasReachedEndOfSuperstep()) { + else if (hasReachedEndOfSuperstep()) { + return false; - } - else { - // More data is coming... - continue; - } + } // else: More data is coming... } - - currentRecordDeserializer = recordDeserializers[channelIndex]; - currentRecordDeserializer.setNextBuffer(nextBuffer); } } @@ -105,34 +100,4 @@ public void clearBuffers() { } } } - - @Override - public void sendTaskEvent(TaskEvent event) throws IOException, InterruptedException { - reader.sendTaskEvent(event); - } - - @Override - public boolean isFinished() { - return reader.isFinished(); - } - - @Override - public void subscribeToTaskEvent(EventListener eventListener, Class eventType) { - reader.subscribeToTaskEvent(eventListener, eventType); - } - - @Override - public void setIterativeReader() { - reader.setIterativeReader(); - } - - @Override - public void startNextSuperstep() { - reader.startNextSuperstep(); - } - - @Override - public boolean hasReachedEndOfSuperstep() { - return reader.hasReachedEndOfSuperstep(); - } } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/reader/BufferReader.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/reader/BufferReader.java index fca27faede1b5..ca5960959a2ec 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/reader/BufferReader.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/reader/BufferReader.java @@ -18,489 +18,33 @@ package org.apache.flink.runtime.io.network.api.reader; -import com.google.common.collect.Maps; -import org.apache.flink.runtime.deployment.PartitionConsumerDeploymentDescriptor; -import org.apache.flink.runtime.deployment.PartitionInfo; -import org.apache.flink.runtime.deployment.PartitionInfo.PartitionLocation; -import org.apache.flink.runtime.event.task.AbstractEvent; -import org.apache.flink.runtime.event.task.TaskEvent; -import org.apache.flink.runtime.execution.RuntimeEnvironment; -import org.apache.flink.runtime.executiongraph.ExecutionAttemptID; -import org.apache.flink.runtime.io.network.ConnectionManager; -import org.apache.flink.runtime.io.network.NetworkEnvironment; -import org.apache.flink.runtime.io.network.RemoteAddress; -import org.apache.flink.runtime.io.network.TaskEventDispatcher; -import org.apache.flink.runtime.io.network.api.EndOfPartitionEvent; -import org.apache.flink.runtime.io.network.api.EndOfSuperstepEvent; -import org.apache.flink.runtime.io.network.api.serialization.EventSerializer; import org.apache.flink.runtime.io.network.buffer.Buffer; -import org.apache.flink.runtime.io.network.buffer.BufferPool; -import org.apache.flink.runtime.io.network.buffer.BufferProvider; -import org.apache.flink.runtime.io.network.partition.IntermediateResultPartitionProvider; -import org.apache.flink.runtime.io.network.partition.consumer.InputChannel; -import org.apache.flink.runtime.io.network.partition.consumer.LocalInputChannel; -import org.apache.flink.runtime.io.network.partition.consumer.RemoteInputChannel; -import org.apache.flink.runtime.io.network.partition.consumer.UnknownInputChannel; -import org.apache.flink.runtime.jobgraph.IntermediateDataSetID; -import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID; -import org.apache.flink.runtime.util.event.EventListener; -import org.apache.flink.runtime.util.event.EventNotificationHandler; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; +import org.apache.flink.runtime.io.network.partition.consumer.BufferOrEvent; +import org.apache.flink.runtime.io.network.partition.consumer.InputGate; import java.io.IOException; -import java.util.ArrayList; -import java.util.List; -import java.util.Map; -import java.util.concurrent.BlockingQueue; -import java.util.concurrent.LinkedBlockingQueue; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicReference; -import static com.google.common.base.Preconditions.checkArgument; -import static com.google.common.base.Preconditions.checkNotNull; -import static com.google.common.base.Preconditions.checkState; - -public final class BufferReader implements BufferReaderBase { - - private static final Logger LOG = LoggerFactory.getLogger(BufferReader.class); - - private final Object requestLock = new Object(); - - private final RuntimeEnvironment environment; - - private final NetworkEnvironment networkEnvironment; - - private final EventNotificationHandler taskEventHandler = new EventNotificationHandler(); - - private final IntermediateDataSetID consumedResultId; - - private final int totalNumberOfInputChannels; - - private final int queueToRequest; - - private final Map inputChannels; - - private BufferPool bufferPool; - - private boolean isReleased; - - private boolean isTaskEvent; - - // ------------------------------------------------------------------------ - - private final BlockingQueue inputChannelsWithData = new LinkedBlockingQueue(); - - private final AtomicReference> readerListener = new AtomicReference>(null); - - private final List pendingEvents = new ArrayList(); - - private int numberOfUninitializedChannels; - - // ------------------------------------------------------------------------ - - private boolean isIterativeReader; - - private int currentNumEndOfSuperstepEvents; - - private int channelIndexOfLastReadBuffer = -1; - - private boolean hasRequestedPartitions = false; - - public BufferReader(RuntimeEnvironment environment, NetworkEnvironment networkEnvironment, IntermediateDataSetID consumedResultId, int numberOfInputChannels, int queueToRequest) { - - this.consumedResultId = checkNotNull(consumedResultId); - // Note: the environment is not fully initialized yet - this.environment = checkNotNull(environment); - - this.networkEnvironment = networkEnvironment; - - checkArgument(numberOfInputChannels >= 0); - this.totalNumberOfInputChannels = numberOfInputChannels; - - checkArgument(queueToRequest >= 0); - this.queueToRequest = queueToRequest; - - this.inputChannels = Maps.newHashMapWithExpectedSize(numberOfInputChannels); - } - - // ------------------------------------------------------------------------ - // Properties - // ------------------------------------------------------------------------ - - public void setBufferPool(BufferPool bufferPool) { - checkArgument(bufferPool.getNumberOfRequiredMemorySegments() == totalNumberOfInputChannels, "Buffer pool has not enough buffers for this reader."); - checkState(this.bufferPool == null, "Buffer pool has already been set for reader."); - - this.bufferPool = checkNotNull(bufferPool); - } - - public IntermediateDataSetID getConsumedResultId() { - return consumedResultId; - } - - public String getTaskNameWithSubtasks() { - return environment.getTaskNameWithSubtasks(); - } - - public IntermediateResultPartitionProvider getIntermediateResultPartitionProvider() { - return networkEnvironment.getPartitionManager(); - } - - public TaskEventDispatcher getTaskEventDispatcher() { - return networkEnvironment.getTaskEventDispatcher(); - } - - public ConnectionManager getConnectionManager() { - return networkEnvironment.getConnectionManager(); - } - - /** - * Returns the total number of input channels for this reader. - *

- * Note: This number might be smaller the current number of input channels - * of the reader as channels are possibly updated during runtime. - */ - @Override - public int getNumberOfInputChannels() { - return totalNumberOfInputChannels; - } - - public BufferProvider getBufferProvider() { - return bufferPool; - } - - public void setInputChannel(IntermediateResultPartitionID partitionId, InputChannel inputChannel) { - synchronized (requestLock) { - if (inputChannels.put(checkNotNull(partitionId), checkNotNull(inputChannel)) == null && - inputChannel.getClass() == UnknownInputChannel.class) { - - numberOfUninitializedChannels++; - } - } - } - - public void updateInputChannel(PartitionInfo partitionInfo) throws IOException { - synchronized (requestLock) { - if (isReleased) { - // There was a race with a task failure/cancel - return; - } - - final IntermediateResultPartitionID partitionId = partitionInfo.getPartitionId(); - - InputChannel current = inputChannels.get(partitionId); - - if (current.getClass() == UnknownInputChannel.class) { - UnknownInputChannel unknownChannel = (UnknownInputChannel) current; - - InputChannel newChannel; - - if (partitionInfo.getProducerLocation() == PartitionLocation.REMOTE) { - newChannel = unknownChannel.toRemoteInputChannel(partitionInfo.getProducerAddress()); - } - else if (partitionInfo.getProducerLocation() == PartitionLocation.LOCAL) { - newChannel = unknownChannel.toLocalInputChannel(); - } - else { - throw new IllegalStateException("Tried to update unknown channel with unknown channel."); - } - - inputChannels.put(partitionId, newChannel); - - - newChannel.requestIntermediateResultPartition(queueToRequest); - - for (TaskEvent event : pendingEvents) { - newChannel.sendTaskEvent(event); - } - - if (--numberOfUninitializedChannels == 0) { - pendingEvents.clear(); - } - } - } - } - - // ------------------------------------------------------------------------ - // Consume - // ------------------------------------------------------------------------ - - @Override - public void requestPartitionsOnce() throws IOException { - if (!hasRequestedPartitions) { - // Sanity check - if (totalNumberOfInputChannels != inputChannels.size()) { - throw new IllegalStateException("Mismatch between number of total input channels and the currently number of set input channels."); - } - - synchronized (requestLock) { - for (InputChannel inputChannel : inputChannels.values()) { - inputChannel.requestIntermediateResultPartition(queueToRequest); - } - } +/** + * A buffer-oriented reader. + */ +public final class BufferReader extends AbstractReader { - hasRequestedPartitions = true; - } + public BufferReader(InputGate gate) { + super(gate); } - @Override - public Buffer getNextBufferBlocking() throws IOException, InterruptedException { - requestPartitionsOnce(); - + public Buffer getNextBuffer() throws IOException, InterruptedException { while (true) { - if (Thread.interrupted()) { - throw new InterruptedException(); - } + final BufferOrEvent bufferOrEvent = inputGate.getNextBufferOrEvent(); - // Possibly block until data is available at one of the input channels - InputChannel currentChannel = null; - while (currentChannel == null) { - currentChannel = inputChannelsWithData.poll(2000, TimeUnit.MILLISECONDS); - } - - isTaskEvent = false; - - final Buffer buffer = currentChannel.getNextBuffer(); - - if (buffer == null) { - throw new IllegalStateException("Bug in reader logic: queried for a buffer although none was available."); - } - - if (buffer.isBuffer()) { - channelIndexOfLastReadBuffer = currentChannel.getChannelIndex(); - return buffer; + if (bufferOrEvent.isBuffer()) { + return bufferOrEvent.getBuffer(); } else { - try { - final AbstractEvent event = EventSerializer.fromBuffer(buffer, getClass().getClassLoader()); - - // ------------------------------------------------------------ - // Runtime events - // ------------------------------------------------------------ - // Note: We can not assume that every channel will be finished - // with an according event. In failure cases or iterations the - // consumer task finishes earlier and has to release all - // resources. - // ------------------------------------------------------------ - if (event.getClass() == EndOfPartitionEvent.class) { - currentChannel.releaseAllResources(); - - return null; - } - else if (event.getClass() == EndOfSuperstepEvent.class) { - incrementEndOfSuperstepEventAndCheck(); - - return null; - } - // ------------------------------------------------------------ - // Task events (user) - // ------------------------------------------------------------ - else if (event instanceof TaskEvent) { - taskEventHandler.publish((TaskEvent) event); - - isTaskEvent = true; - - return null; - } - else { - throw new IllegalStateException("Received unexpected event " + event + " from input channel " + currentChannel + "."); - } - } - catch (Throwable t) { - throw new IOException("Error while reading event: " + t.getMessage(), t); - } - finally { - buffer.recycle(); - } - } - } - } - - @Override - public Buffer getNextBuffer(Buffer exchangeBuffer) { - throw new UnsupportedOperationException("Buffer exchange when reading data is not yet supported."); - } - - @Override - public int getChannelIndexOfLastBuffer() { - return channelIndexOfLastReadBuffer; - } - - @Override - public boolean isTaskEvent() { - return isTaskEvent; - } - - @Override - public boolean isFinished() { - synchronized (requestLock) { - for (InputChannel inputChannel : inputChannels.values()) { - if (!inputChannel.isReleased()) { - return false; - } - } - } - - return true; - } - - public void releaseAllResources() throws IOException { - synchronized (requestLock) { - if (!isReleased) { - try { - for (InputChannel inputChannel : inputChannels.values()) { - try { - inputChannel.releaseAllResources(); - } - catch (IOException e) { - LOG.warn("Error during release of channel resources: " + e.getMessage(), e); - } - } - - // The buffer pool can actually be destroyed immediately after the - // reader received all of the data from the input channels. - if (bufferPool != null) { - bufferPool.destroy(); - } - } - finally { - isReleased = true; + if (handleEvent(bufferOrEvent.getEvent())) { + return null; } } } } - - // ------------------------------------------------------------------------ - // Channel notifications - // ------------------------------------------------------------------------ - - public void onAvailableInputChannel(InputChannel inputChannel) { - inputChannelsWithData.add(inputChannel); - - if (readerListener.get() != null) { - readerListener.get().onEvent(this); - } - } - - @Override - public void subscribeToReader(EventListener listener) { - if (!this.readerListener.compareAndSet(null, listener)) { - throw new IllegalStateException(listener + " is already registered as a record availability listener"); - } - } - - // ------------------------------------------------------------------------ - // Task events - // ------------------------------------------------------------------------ - - @Override - public void sendTaskEvent(TaskEvent event) throws IOException, InterruptedException { - // This can be improved by just serializing the event once for all - // remote input channels. - synchronized (requestLock) { - for (InputChannel inputChannel : inputChannels.values()) { - inputChannel.sendTaskEvent(event); - } - - if (numberOfUninitializedChannels > 0) { - pendingEvents.add(event); - } - } - } - - @Override - public void subscribeToTaskEvent(EventListener listener, Class eventType) { - taskEventHandler.subscribe(listener, eventType); - } - - // ------------------------------------------------------------------------ - // Iteration end of superstep events - // ------------------------------------------------------------------------ - - @Override - public void setIterativeReader() { - isIterativeReader = true; - } - - @Override - public void startNextSuperstep() { - checkState(isIterativeReader, "Tried to start next superstep in a non-iterative reader."); - checkState(currentNumEndOfSuperstepEvents == totalNumberOfInputChannels, - "Tried to start next superstep before reaching end of previous superstep."); - - currentNumEndOfSuperstepEvents = 0; - } - - @Override - public boolean hasReachedEndOfSuperstep() { - return currentNumEndOfSuperstepEvents == totalNumberOfInputChannels; - } - - private boolean incrementEndOfSuperstepEventAndCheck() { - checkState(isIterativeReader, "Received end of superstep event in a non-iterative reader."); - - currentNumEndOfSuperstepEvents++; - - checkState(currentNumEndOfSuperstepEvents <= totalNumberOfInputChannels, - "Received too many (" + currentNumEndOfSuperstepEvents + ") end of superstep events."); - - return currentNumEndOfSuperstepEvents == totalNumberOfInputChannels; - } - - // ------------------------------------------------------------------------ - - @Override - public String toString() { - return String.format("BufferReader %s [task: %s, current/total number of input channels: %d/%d]", - consumedResultId, getTaskNameWithSubtasks(), inputChannels.size(), totalNumberOfInputChannels); - } - - public static BufferReader create(RuntimeEnvironment runtimeEnvironment, NetworkEnvironment networkEnvironment, PartitionConsumerDeploymentDescriptor desc) { - // The consumed intermediate data set (all partitions are part of this data set) - final IntermediateDataSetID resultId = desc.getResultId(); - - // The queue to request from each consumed partition - final int queueIndex = desc.getQueueIndex(); - - // There is one input channel for each consumed partition - final PartitionInfo[] partitions = desc.getPartitions(); - final int numberOfInputChannels = partitions.length; - - final BufferReader reader = new BufferReader(runtimeEnvironment, networkEnvironment, resultId, numberOfInputChannels, queueIndex); - - // Create input channels - final InputChannel[] inputChannels = new InputChannel[numberOfInputChannels]; - - int channelIndex = 0; - - for (PartitionInfo partition : partitions) { - final ExecutionAttemptID producerExecutionId = partition.getProducerExecutionId(); - final IntermediateResultPartitionID partitionId = partition.getPartitionId(); - - final PartitionLocation producerLocation = partition.getProducerLocation(); - - switch (producerLocation) { - case LOCAL: - inputChannels[channelIndex] = new LocalInputChannel(channelIndex, producerExecutionId, partitionId, reader); - break; - - case REMOTE: - final RemoteAddress producerAddress = checkNotNull(partition.getProducerAddress(), "Missing producer address for remote intermediate result partition."); - - inputChannels[channelIndex] = new RemoteInputChannel(channelIndex, producerExecutionId, partitionId, reader, producerAddress); - break; - - case UNKNOWN: - inputChannels[channelIndex] = new UnknownInputChannel(channelIndex, producerExecutionId, partitionId, reader); - break; - } - - reader.setInputChannel(partitionId, inputChannels[channelIndex]); - - channelIndex++; - } - - return reader; - } } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/reader/BufferReaderBase.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/reader/BufferReaderBase.java deleted file mode 100644 index d1dbefdfe55a9..0000000000000 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/reader/BufferReaderBase.java +++ /dev/null @@ -1,92 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.runtime.io.network.api.reader; - -import org.apache.flink.runtime.io.network.buffer.Buffer; -import org.apache.flink.runtime.io.network.partition.consumer.InputChannel; -import org.apache.flink.runtime.io.network.partition.consumer.RemoteInputChannel; -import org.apache.flink.runtime.util.event.EventListener; - -import java.io.IOException; - -/** - * A buffer-oriented runtime result reader. - *

- * {@link BufferReaderBase} is the runtime API for consuming results. Events - * are handled by the reader and users can query for buffers with - * {@link #getNextBufferBlocking()} or {@link #getNextBuffer(Buffer)}. - *

- * Important: If {@link #getNextBufferBlocking()} is used, it is - * necessary to release the returned buffers with {@link Buffer#recycle()} - * after they are consumed. - */ -public interface BufferReaderBase extends ReaderBase { - - /** - * Returns the next queued {@link Buffer} from one of the {@link RemoteInputChannel} - * instances attached to this reader. The are no ordering guarantees with - * respect to which channel is queried for data. - *

- * Important: it is necessary to release buffers, which - * are returned by the reader via {@link Buffer#recycle()}, because they - * are a pooled resource. If not recycled, the network stack will run out - * of buffers and deadlock. - * - * @see #getChannelIndexOfLastBuffer() - */ - Buffer getNextBufferBlocking() throws IOException, InterruptedException; - - /** - * {@link #getNextBufferBlocking()} requires the user to quickly recycle the - * returned buffer. For a fully buffer-oriented runtime, we need to - * support a variant of this method, which allows buffers to be exchanged - * in order to save unnecessary memory copies between buffer pools. - *

- * Currently this is not a problem, because the only "users" of the buffer- - * oriented API are the record-oriented readers, which immediately - * deserialize the buffer and recycle it. - */ - Buffer getNextBuffer(Buffer exchangeBuffer) throws IOException, InterruptedException; - - /** - * Returns a channel index for the last {@link Buffer} instance returned by - * {@link #getNextBufferBlocking()} or {@link #getNextBuffer(Buffer)}. - *

- * The returned index is guaranteed to be the same for all buffers read by - * the same {@link RemoteInputChannel} instance. This is useful when data spans - * multiple buffers returned by this reader. - *

- * Initially returns -1 and if multiple readers are unioned, - * the local channel indexes are mapped to the sequence from 0 to n-1. - */ - int getChannelIndexOfLastBuffer(); - - /** - * Returns the total number of {@link InputChannel} instances, from which this - * reader gets its data. - */ - int getNumberOfInputChannels(); - - boolean isTaskEvent(); - - void subscribeToReader(EventListener listener); - - void requestPartitionsOnce() throws IOException; - -} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/reader/MutableRecordReader.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/reader/MutableRecordReader.java index 75d4f21fd7941..d7cc7e92821f5 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/reader/MutableRecordReader.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/reader/MutableRecordReader.java @@ -19,13 +19,14 @@ package org.apache.flink.runtime.io.network.api.reader; import org.apache.flink.core.io.IOReadableWritable; +import org.apache.flink.runtime.io.network.partition.consumer.InputGate; import java.io.IOException; public class MutableRecordReader extends AbstractRecordReader implements MutableReader { - public MutableRecordReader(BufferReaderBase reader) { - super(reader); + public MutableRecordReader(InputGate inputGate) { + super(inputGate); } @Override diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/reader/ReaderBase.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/reader/ReaderBase.java index bb6ec44440a27..2a0a6df9664b2 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/reader/ReaderBase.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/reader/ReaderBase.java @@ -24,23 +24,22 @@ import org.apache.flink.runtime.util.event.EventListener; /** - * The basic API every reader (both buffer- and record-oriented) has to support. + * The basic API for every reader. */ public interface ReaderBase { - // ------------------------------------------------------------------------ - // Properties - // ------------------------------------------------------------------------ - + /** + * Returns whether the reader has consumed the input. + */ boolean isFinished(); // ------------------------------------------------------------------------ - // Events + // Task events // ------------------------------------------------------------------------ - void subscribeToTaskEvent(EventListener eventListener, Class eventType); + void sendTaskEvent(TaskEvent event) throws IOException; - void sendTaskEvent(TaskEvent event) throws IOException, InterruptedException; + void registerTaskEventListener(EventListener listener, Class eventType); // ------------------------------------------------------------------------ // Iterations diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/reader/RecordReader.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/reader/RecordReader.java index db992a5916dde..b1395e3960570 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/reader/RecordReader.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/reader/RecordReader.java @@ -19,6 +19,7 @@ package org.apache.flink.runtime.io.network.api.reader; import org.apache.flink.core.io.IOReadableWritable; +import org.apache.flink.runtime.io.network.partition.consumer.InputGate; import java.io.IOException; @@ -28,8 +29,8 @@ public class RecordReader extends AbstractRecordRe private T currentRecord; - public RecordReader(BufferReaderBase reader, Class recordType) { - super(reader); + public RecordReader(InputGate inputGate, Class recordType) { + super(inputGate); this.recordType = recordType; } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/reader/UnionBufferReader.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/reader/UnionBufferReader.java deleted file mode 100644 index 241e212eedf16..0000000000000 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/reader/UnionBufferReader.java +++ /dev/null @@ -1,298 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.runtime.io.network.api.reader; - -import org.apache.flink.runtime.event.task.TaskEvent; -import org.apache.flink.runtime.io.network.buffer.Buffer; -import org.apache.flink.runtime.util.event.EventListener; - -import java.io.IOException; -import java.util.HashMap; -import java.util.HashSet; -import java.util.Map; -import java.util.Set; -import java.util.concurrent.BlockingQueue; -import java.util.concurrent.LinkedBlockingQueue; - -import static com.google.common.base.Preconditions.checkArgument; -import static com.google.common.base.Preconditions.checkNotNull; -import static com.google.common.base.Preconditions.checkState; - -/** - * A buffer-oriented reader, which unions multiple {@link BufferReader} - * instances. - */ -public class UnionBufferReader implements BufferReaderBase { - - private final BufferReaderBase[] readers; - - private final DataAvailabilityListener dataAvailabilityListener; - - // Set of readers, which are not closed yet - private final Set remainingReaders; - - // Logical channel index offset for each reader - private final Map readerToIndexOffsetMap = new HashMap(); - - private int totalNumInputChannels; - - private BufferReaderBase currentReader; - - private int currentReaderChannelIndexOffset; - - private int channelIndexOfLastReadBuffer = -1; - - private boolean isIterative; - - private boolean hasRequestedPartitions; - - private boolean isTaskEvent; - - public UnionBufferReader(BufferReaderBase... readers) { - checkNotNull(readers); - checkArgument(readers.length >= 2, "Union buffer reader must be initialized with at least two individual buffer readers"); - - this.readers = readers; - this.remainingReaders = new HashSet(readers.length + 1, 1.0F); - - this.dataAvailabilityListener = new DataAvailabilityListener(this); - - int currentChannelIndexOffset = 0; - - for (int i = 0; i < readers.length; i++) { - BufferReaderBase reader = readers[i]; - - reader.subscribeToReader(dataAvailabilityListener); - - remainingReaders.add(reader); - readerToIndexOffsetMap.put(reader, currentChannelIndexOffset); - - totalNumInputChannels += reader.getNumberOfInputChannels(); - currentChannelIndexOffset += reader.getNumberOfInputChannels(); - } - } - - @Override - public void requestPartitionsOnce() throws IOException { - if (!hasRequestedPartitions) { - for (BufferReaderBase reader : readers) { - reader.requestPartitionsOnce(); - } - - hasRequestedPartitions = true; - } - } - - - @Override - public Buffer getNextBufferBlocking() throws IOException, InterruptedException { - requestPartitionsOnce(); - - do { - if (currentReader == null) { - // Finished when all readers are finished - if (isFinished()) { - dataAvailabilityListener.clear(); - return null; - } - // Finished with superstep when all readers finished superstep - else if (isIterative && remainingReaders.isEmpty()) { - resetRemainingReaders(); - return null; - } - else { - while (true) { - currentReader = dataAvailabilityListener.getNextReaderBlocking(); - currentReaderChannelIndexOffset = readerToIndexOffsetMap.get(currentReader); - - if (isIterative && !remainingReaders.contains(currentReader)) { - // If the current reader already received its end - // of superstep event and notified the union reader - // about newer data *before* all other readers have - // done so, we delay this notifications. - dataAvailabilityListener.addReader(currentReader); - } - else { - break; - } - } - } - } - - Buffer buffer = currentReader.getNextBufferBlocking(); - channelIndexOfLastReadBuffer = currentReaderChannelIndexOffset + currentReader.getChannelIndexOfLastBuffer(); - - isTaskEvent = false; - - if (buffer == null) { - if (currentReader.isFinished() || currentReader.hasReachedEndOfSuperstep()) { - remainingReaders.remove(currentReader); - } - - currentReader = null; - - return null; - } - else { - currentReader = null; - return buffer; - } - } while (true); - } - - @Override - public Buffer getNextBuffer(Buffer exchangeBuffer) throws IOException, InterruptedException { - throw new UnsupportedOperationException("Buffer exchange when reading data is not yet supported."); - } - - @Override - public int getChannelIndexOfLastBuffer() { - return channelIndexOfLastReadBuffer; - } - - @Override - public int getNumberOfInputChannels() { - return totalNumInputChannels; - } - - @Override - public boolean isTaskEvent() { - return isTaskEvent; - } - - @Override - public void subscribeToReader(EventListener listener) { - dataAvailabilityListener.registerListener(listener); - } - - @Override - public boolean isFinished() { - for (BufferReaderBase reader : readers) { - if (!reader.isFinished()) { - return false; - } - } - - return true; - } - - private void resetRemainingReaders() { - checkState(isIterative, "Tried to reset remaining reader with non-iterative reader."); - checkState(remainingReaders.isEmpty(), "Tried to reset remaining readers, but there are some remaining readers."); - for (BufferReaderBase reader : readers) { - remainingReaders.add(reader); - } - } - - // ------------------------------------------------------------------------ - // TaskEvents - // ------------------------------------------------------------------------ - - @Override - public void subscribeToTaskEvent(EventListener eventListener, Class eventType) { - for (BufferReaderBase reader : readers) { - reader.subscribeToTaskEvent(eventListener, eventType); - } - } - - @Override - public void sendTaskEvent(TaskEvent event) throws IOException, InterruptedException { - for (BufferReaderBase reader : readers) { - reader.sendTaskEvent(event); - } - } - - // ------------------------------------------------------------------------ - // Iteration end of superstep events - // ------------------------------------------------------------------------ - - @Override - public void setIterativeReader() { - isIterative = true; - - for (BufferReaderBase reader : readers) { - reader.setIterativeReader(); - } - } - - @Override - public void startNextSuperstep() { - for (BufferReaderBase reader : readers) { - reader.startNextSuperstep(); - } - } - - @Override - public boolean hasReachedEndOfSuperstep() { - for (BufferReaderBase reader : readers) { - if (!reader.hasReachedEndOfSuperstep()) { - return false; - } - } - - return true; - } - - // ------------------------------------------------------------------------ - // Data availability notifications - // ------------------------------------------------------------------------ - - private static class DataAvailabilityListener implements EventListener { - - private final UnionBufferReader unionReader; - - private final BlockingQueue readersWithData = new LinkedBlockingQueue(); - - private volatile EventListener registeredListener; - - private DataAvailabilityListener(UnionBufferReader unionReader) { - this.unionReader = unionReader; - } - - @Override - public void onEvent(BufferReaderBase reader) { - readersWithData.add(reader); - - if (registeredListener != null) { - registeredListener.onEvent(unionReader); - } - } - - BufferReaderBase getNextReaderBlocking() throws InterruptedException { - return readersWithData.take(); - } - - void addReader(BufferReaderBase reader) { - readersWithData.add(reader); - } - - void clear() { - readersWithData.clear(); - } - - void registerListener(EventListener listener) { - if (registeredListener == null) { - registeredListener = listener; - } - else { - throw new IllegalStateException("Already registered listener."); - } - } - } -} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/BufferWriter.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/BufferWriter.java index 6cb1831f40921..b9c6d33631197 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/BufferWriter.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/BufferWriter.java @@ -27,7 +27,7 @@ import org.apache.flink.runtime.io.network.partition.IntermediateResultPartition; import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID; import org.apache.flink.runtime.util.event.EventListener; -import org.apache.flink.runtime.util.event.EventNotificationHandler; +import org.apache.flink.runtime.io.network.api.TaskEventHandler; import java.io.IOException; @@ -44,7 +44,7 @@ public final class BufferWriter implements EventListener { private final IntermediateResultPartition partition; - private final EventNotificationHandler taskEventHandler = new EventNotificationHandler(); + private final TaskEventHandler taskEventHandler = new TaskEventHandler(); public BufferWriter(IntermediateResultPartition partition) { this.partition = partition; @@ -104,7 +104,7 @@ public boolean isFinished() { // Event handling // ------------------------------------------------------------------------ - public EventNotificationHandler getTaskEventHandler() { + public TaskEventHandler getTaskEventHandler() { return taskEventHandler; } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/BufferOrEvent.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/BufferOrEvent.java new file mode 100644 index 0000000000000..032772c2c2daf --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/BufferOrEvent.java @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.io.network.partition.consumer; + +import org.apache.flink.runtime.event.task.AbstractEvent; +import org.apache.flink.runtime.io.network.buffer.Buffer; + +import static com.google.common.base.Preconditions.checkArgument; + +/** + * Either type for {@link Buffer} or {@link AbstractEvent} instances tagged with the channel index, + * from which they were received. + */ +public class BufferOrEvent { + + private final Buffer buffer; + + private final AbstractEvent event; + + private int channelIndex; + + public BufferOrEvent(Buffer buffer, int channelIndex) { + this.buffer = buffer; + this.event = null; + this.channelIndex = channelIndex; + } + + public BufferOrEvent(AbstractEvent event, int channelIndex) { + this.buffer = null; + this.event = event; + this.channelIndex = channelIndex; + } + + public boolean isBuffer() { + return buffer != null; + } + + public boolean isEvent() { + return event != null; + } + + public Buffer getBuffer() { + return buffer; + } + + public AbstractEvent getEvent() { + return event; + } + + public int getChannelIndex() { + return channelIndex; + } + + public void setChannelIndex(int channelIndex) { + checkArgument(channelIndex >= 0); + this.channelIndex = channelIndex; + } +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/InputChannel.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/InputChannel.java index ea8f459db1f43..31b67ca659a56 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/InputChannel.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/InputChannel.java @@ -20,7 +20,6 @@ import org.apache.flink.runtime.event.task.TaskEvent; import org.apache.flink.runtime.executiongraph.ExecutionAttemptID; -import org.apache.flink.runtime.io.network.api.reader.BufferReader; import org.apache.flink.runtime.io.network.buffer.Buffer; import org.apache.flink.runtime.io.network.partition.queue.IntermediateResultPartitionQueue; import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID; @@ -28,7 +27,7 @@ import java.io.IOException; /** - * An input channel is the consumer of a single queue of an {@link IntermediateResultPartitionQueue}. + * An input channel is the consumer of a single subpartition of an {@link IntermediateResultPartitionQueue}. *

* For each channel, the consumption life cycle is as follows: *

    @@ -45,13 +44,13 @@ public abstract class InputChannel { protected final IntermediateResultPartitionID partitionId; - protected final BufferReader reader; + protected final SingleInputGate inputGate; - protected InputChannel(int channelIndex, ExecutionAttemptID producerExecutionId, IntermediateResultPartitionID partitionId, BufferReader reader) { + protected InputChannel(SingleInputGate inputGate, int channelIndex, ExecutionAttemptID producerExecutionId, IntermediateResultPartitionID partitionId) { + this.inputGate = inputGate; this.channelIndex = channelIndex; this.producerExecutionId = producerExecutionId; this.partitionId = partitionId; - this.reader = reader; } // ------------------------------------------------------------------------ @@ -76,11 +75,10 @@ public String toString() { } /** - * Notifies the {@link BufferReader}, which consumes this input channel - * about an available {@link Buffer} instance. + * Notifies the owning {@link SingleInputGate} about an available {@link Buffer} instance. */ - protected void notifyReaderAboutAvailableBuffer() { - reader.onAvailableInputChannel(this); + protected void notifyAvailableBuffer() { + inputGate.onAvailableBuffer(this); } // ------------------------------------------------------------------------ @@ -97,7 +95,7 @@ protected void notifyReaderAboutAvailableBuffer() { public abstract void requestIntermediateResultPartition(int queueIndex) throws IOException; /** - * Returns the next buffer from the consumed queue. + * Returns the next buffer from the consumed subpartition. */ public abstract Buffer getNextBuffer() throws IOException; @@ -106,10 +104,12 @@ protected void notifyReaderAboutAvailableBuffer() { // ------------------------------------------------------------------------ /** - * Sends a {@link TaskEvent} back to the partition producer. + * Sends a {@link TaskEvent} back to the task producing the consumed result partition. *

    - * Important: This only works if the producer task is - * running at the same time. + * Important: The producing task has to be running to receive backwards events. + * This means that the result type needs to be pipelined and the task logic has to ensure that + * the producer will wait for all backwards events. Otherwise, this will lead to an Exception + * at runtime. */ public abstract void sendTaskEvent(TaskEvent event) throws IOException; diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/InputGate.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/InputGate.java new file mode 100644 index 0000000000000..8d28084fa4cb7 --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/InputGate.java @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.io.network.partition.consumer; + +import org.apache.flink.runtime.event.task.TaskEvent; +import org.apache.flink.runtime.util.event.EventListener; + +import java.io.IOException; + +public interface InputGate { + + public int getNumberOfInputChannels(); + + public boolean isFinished(); + + public void requestPartitions() throws IOException; + + public BufferOrEvent getNextBufferOrEvent() throws IOException, InterruptedException; + + public void sendTaskEvent(TaskEvent event) throws IOException; + + public void registerListener(EventListener listener); + +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/LocalInputChannel.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/LocalInputChannel.java index 150aaea7dc9c0..8b5823435060d 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/LocalInputChannel.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/LocalInputChannel.java @@ -21,8 +21,9 @@ import com.google.common.base.Optional; import org.apache.flink.runtime.event.task.TaskEvent; import org.apache.flink.runtime.executiongraph.ExecutionAttemptID; -import org.apache.flink.runtime.io.network.api.reader.BufferReader; +import org.apache.flink.runtime.io.network.TaskEventDispatcher; import org.apache.flink.runtime.io.network.buffer.Buffer; +import org.apache.flink.runtime.io.network.partition.IntermediateResultPartitionManager; import org.apache.flink.runtime.io.network.partition.queue.IntermediateResultPartitionQueueIterator; import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID; import org.apache.flink.runtime.util.event.NotificationListener; @@ -31,23 +32,37 @@ import java.io.IOException; +import static com.google.common.base.Preconditions.checkNotNull; import static com.google.common.base.Preconditions.checkState; /** - * An input channel, which requests a local partition queue. + * An input channel, which requests a local subpartition. */ public class LocalInputChannel extends InputChannel implements NotificationListener { private static final Logger LOG = LoggerFactory.getLogger(LocalInputChannel.class); + private final IntermediateResultPartitionManager partitionManager; + + private final TaskEventDispatcher taskEventDispatcher; + private IntermediateResultPartitionQueueIterator queueIterator; private boolean isReleased; - private Buffer lookAhead; + private volatile Buffer lookAhead; + + public LocalInputChannel( + SingleInputGate gate, int channelIndex, + ExecutionAttemptID producerExecutionId, + IntermediateResultPartitionID partitionId, + IntermediateResultPartitionManager partitionManager, + TaskEventDispatcher taskEventDispatcher) { + + super(gate, channelIndex, producerExecutionId, partitionId); - public LocalInputChannel(int channelIndex, ExecutionAttemptID producerExecutionId, IntermediateResultPartitionID partitionId, BufferReader reader) { - super(channelIndex, producerExecutionId, partitionId, reader); + this.partitionManager = checkNotNull(partitionManager); + this.taskEventDispatcher = checkNotNull(taskEventDispatcher); } // ------------------------------------------------------------------------ @@ -61,8 +76,8 @@ public void requestIntermediateResultPartition(int queueIndex) throws IOExceptio LOG.debug("Requesting queue {} from LOCAL partition {}.", partitionId, queueIndex); } - queueIterator = reader.getIntermediateResultPartitionProvider() - .getIntermediateResultPartitionIterator(producerExecutionId, partitionId, queueIndex, Optional.of(reader.getBufferProvider())); + queueIterator = partitionManager.getIntermediateResultPartitionIterator( + producerExecutionId, partitionId, queueIndex, Optional.of(inputGate.getBufferProvider())); getNextLookAhead(); } @@ -93,7 +108,7 @@ public Buffer getNextBuffer() throws IOException { public void sendTaskEvent(TaskEvent event) throws IOException { checkState(queueIterator != null, "Tried to send task event to producer before requesting a queue."); - if (!reader.getTaskEventDispatcher().publish(producerExecutionId, partitionId, event)) { + if (!taskEventDispatcher.publish(producerExecutionId, partitionId, event)) { throw new IOException("Error while publishing event " + event + " to producer. The producer could not be found."); } } @@ -139,7 +154,7 @@ public String toString() { @Override public void onNotification() { - notifyReaderAboutAvailableBuffer(); + notifyAvailableBuffer(); } // ------------------------------------------------------------------------ @@ -149,7 +164,7 @@ private void getNextLookAhead() throws IOException { lookAhead = queueIterator.getNextBuffer(); if (lookAhead != null) { - notifyReaderAboutAvailableBuffer(); + notifyAvailableBuffer(); break; } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannel.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannel.java index 616a8a5decd3e..002c90e053025 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannel.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannel.java @@ -20,8 +20,8 @@ import org.apache.flink.runtime.event.task.TaskEvent; import org.apache.flink.runtime.executiongraph.ExecutionAttemptID; +import org.apache.flink.runtime.io.network.ConnectionManager; import org.apache.flink.runtime.io.network.RemoteAddress; -import org.apache.flink.runtime.io.network.api.reader.BufferReader; import org.apache.flink.runtime.io.network.buffer.Buffer; import org.apache.flink.runtime.io.network.buffer.BufferProvider; import org.apache.flink.runtime.io.network.netty.PartitionRequestClient; @@ -35,6 +35,7 @@ import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicReference; +import static com.google.common.base.Preconditions.checkNotNull; import static com.google.common.base.Preconditions.checkState; /** @@ -58,21 +59,25 @@ public class RemoteInputChannel extends InputChannel { private int expectedSequenceNumber = 0; + private ConnectionManager connectionManager; + public RemoteInputChannel( + SingleInputGate gate, int channelIndex, ExecutionAttemptID producerExecutionId, IntermediateResultPartitionID partitionId, - BufferReader reader, - RemoteAddress producerAddress) { + RemoteAddress producerAddress, + ConnectionManager connectionManager) { - super(channelIndex, producerExecutionId, partitionId, reader); + super(gate, channelIndex, producerExecutionId, partitionId); /** * This ID is used by the {@link PartitionRequestClient} to distinguish * between receivers, which share the same TCP connection. */ this.id = new InputChannelID(); - this.producerAddress = producerAddress; + this.producerAddress = checkNotNull(producerAddress); + this.connectionManager = checkNotNull(connectionManager); } // ------------------------------------------------------------------------ @@ -86,7 +91,7 @@ public void requestIntermediateResultPartition(int queueIndex) throws IOExceptio LOG.debug("Requesting queue {} from REMOTE partition {}.", partitionId, queueIndex); } - partitionRequestClient = reader.getConnectionManager().createPartitionRequestClient(producerAddress); + partitionRequestClient = connectionManager.createPartitionRequestClient(producerAddress); partitionRequestClient.requestIntermediateResultPartition(producerExecutionId, partitionId, queueIndex, this); } @@ -171,7 +176,7 @@ public BufferProvider getBufferProvider() throws IOException { return null; } - return reader.getBufferProvider(); + return inputGate.getBufferProvider(); } public void onBuffer(Buffer buffer, int sequenceNumber) { @@ -184,7 +189,7 @@ public void onBuffer(Buffer buffer, int sequenceNumber) { receivedBuffers.add(buffer); expectedSequenceNumber++; - notifyReaderAboutAvailableBuffer(); + notifyAvailableBuffer(); success = true; @@ -205,7 +210,7 @@ public void onBuffer(Buffer buffer, int sequenceNumber) { public void onError(Throwable error) { if (ioError.compareAndSet(null, error instanceof IOException ? (IOException) error : new IOException(error))) { - notifyReaderAboutAvailableBuffer(); + notifyAvailableBuffer(); } } @@ -215,8 +220,8 @@ private void checkIoError() throws IOException { IOException error = ioError.get(); if (error != null) { - throw new IOException(String.format("%s at remote input channel of task '%s': %s].", - error.getClass().getName(), reader.getTaskNameWithSubtasks(), error.getMessage())); + throw new IOException(String.format("%s at remote input channel: %s].", + error.getClass().getName(), error.getMessage())); } } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGate.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGate.java new file mode 100644 index 0000000000000..a8a92d64a5a8d --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGate.java @@ -0,0 +1,397 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.io.network.partition.consumer; + +import com.google.common.collect.Maps; +import org.apache.flink.runtime.deployment.PartitionConsumerDeploymentDescriptor; +import org.apache.flink.runtime.deployment.PartitionInfo; +import org.apache.flink.runtime.event.task.AbstractEvent; +import org.apache.flink.runtime.event.task.TaskEvent; +import org.apache.flink.runtime.executiongraph.ExecutionAttemptID; +import org.apache.flink.runtime.io.network.NetworkEnvironment; +import org.apache.flink.runtime.io.network.RemoteAddress; +import org.apache.flink.runtime.io.network.api.EndOfPartitionEvent; +import org.apache.flink.runtime.io.network.api.serialization.EventSerializer; +import org.apache.flink.runtime.io.network.buffer.Buffer; +import org.apache.flink.runtime.io.network.buffer.BufferPool; +import org.apache.flink.runtime.io.network.buffer.BufferProvider; +import org.apache.flink.runtime.jobgraph.IntermediateDataSetID; +import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID; +import org.apache.flink.runtime.util.event.EventListener; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.CopyOnWriteArrayList; +import java.util.concurrent.LinkedBlockingQueue; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkNotNull; +import static com.google.common.base.Preconditions.checkState; + +/** + * An input gate consumes one or more partitions of a single produced intermediate result. + *

    + * Each intermediate result is partitioned over its producing parallel subtasks; each of these + * partitions is furthermore partitioned into one or more subpartitions. + *

    + * As an example, consider a map-reduce program, where the map operator produces data and the reduce + * operator consumes the produced data. + *

    + * +-----+              +---------------------+              +--------+
    + * | Map | = produce => | Intermediate Result | <= consume = | Reduce |
    + * +-----+              +---------------------+              +--------+
    + * 
    + * When deploying such a program in parallel, the intermediate result will be partitioned over its + * producing parallel subtasks; each of these partitions is furthermore partitioned into one or more + * subpartitions. + *
    + *                            Intermediate result
    + *               +-----------------------------------------+
    + *               |                      +----------------+ |              +-----------------------+
    + * +-------+     | +-------------+  +=> | Subpartition 1 | | <=======+=== | Input Gate | Reduce 1 |
    + * | Map 1 | ==> | | Partition 1 | =|   +----------------+ |         |    +-----------------------+
    + * +-------+     | +-------------+  +=> | Subpartition 2 | | <==+    |
    + *               |                      +----------------+ |    |    | Subpartition request
    + *               |                                         |    |    |
    + *               |                      +----------------+ |    |    |
    + * +-------+     | +-------------+  +=> | Subpartition 1 | | <==+====+
    + * | Map 2 | ==> | | Partition 2 | =|   +----------------+ |    |         +-----------------------+
    + * +-------+     | +-------------+  +=> | Subpartition 2 | | <==+======== | Input Gate | Reduce 2 |
    + *               |                      +----------------+ |              +-----------------------+
    + *               +-----------------------------------------+
    + * 
    + * In the above example, two map subtasks produce the intermediate result in parallel, resulting + * in two partitions (Partition 1 and 2). Each of these partitions is subpartitioned into two + * subpartitions -- one for each parallel reduce subtask. + *

    + */ +public class SingleInputGate implements InputGate { + + private static final Logger LOG = LoggerFactory.getLogger(SingleInputGate.class); + + /** Lock object to guard partition requests and runtime channel updates. */ + private final Object requestLock = new Object(); + + /** + * The ID of the consumed intermediate result. Each input gate consumes partitions of the + * intermediate result specified by this ID. This ID also identifies the input gate at the + * consuming task. + */ + private final IntermediateDataSetID consumedResultId; + + /** + * The index of the consumed subpartition of each consumed partition. This index depends on the + * distribution pattern and both subtask indices of the producing and consuming task. + */ + private final int consumedSubpartitionIndex; + + /** The number of input channels (equivalent to the number of consumed partitions). */ + private final int numberOfInputChannels; + + /** + * Input channels. There is a one input channel for each consumed intermediate result partition. + * We store this in a map for runtime updates of single channels. + */ + private final Map inputChannels; + + /** Channels, which notified this input gate about available data. */ + private final BlockingQueue inputChannelsWithData = new LinkedBlockingQueue(); + + /** + * Buffer pool for incoming buffers. Incoming data from remote channels is copied to buffers + * from this pool. + */ + private BufferPool bufferPool; + + /** Flag indicating whether partitions have been requested. */ + private boolean requestedPartitionsFlag; + + /** Flag indicating whether all resources have been released. */ + private boolean releasedResourcesFlag; + + /** Registered listener to forward buffer notifications to. */ + private final List> registeredListeners = new CopyOnWriteArrayList>(); + + private final List pendingEvents = new ArrayList(); + + private int numberOfUninitializedChannels; + + public SingleInputGate(IntermediateDataSetID consumedResultId, int consumedSubpartitionIndex, int numberOfInputChannels) { + this.consumedResultId = checkNotNull(consumedResultId); + + checkArgument(consumedSubpartitionIndex >= 0); + this.consumedSubpartitionIndex = consumedSubpartitionIndex; + + checkArgument(numberOfInputChannels > 0); + this.numberOfInputChannels = numberOfInputChannels; + + this.inputChannels = Maps.newHashMapWithExpectedSize(numberOfInputChannels); + } + + // ------------------------------------------------------------------------ + // Properties + // ------------------------------------------------------------------------ + + @Override + public int getNumberOfInputChannels() { + return numberOfInputChannels; + } + + public IntermediateDataSetID getConsumedResultId() { + return consumedResultId; + } + + BufferProvider getBufferProvider() { + return bufferPool; + } + + // ------------------------------------------------------------------------ + // Setup/Life-cycle + // ------------------------------------------------------------------------ + + public void setBufferPool(BufferPool bufferPool) { + // Sanity checks + checkArgument(numberOfInputChannels == bufferPool.getNumberOfRequiredMemorySegments(), + "Bug in input gate setup logic: buffer pool has not enough guaranteed buffers " + + "for this input gate. Input gates require at least as many buffers as " + + "there are input channels."); + + checkState(this.bufferPool == null, "Bug in input gate setup logic: buffer pool has" + + "already been set for this input gate."); + + this.bufferPool = checkNotNull(bufferPool); + } + + public void setInputChannel(IntermediateResultPartitionID partitionId, InputChannel inputChannel) { + synchronized (requestLock) { + if (inputChannels.put(checkNotNull(partitionId), checkNotNull(inputChannel)) == null && inputChannel.getClass() == UnknownInputChannel.class) { + + numberOfUninitializedChannels++; + } + } + } + + public void updateInputChannel(PartitionInfo partitionInfo) throws IOException { + synchronized (requestLock) { + if (releasedResourcesFlag) { + // There was a race with a task failure/cancel + return; + } + + final IntermediateResultPartitionID partitionId = partitionInfo.getPartitionId(); + + InputChannel current = inputChannels.get(partitionId); + + if (current.getClass() == UnknownInputChannel.class) { + UnknownInputChannel unknownChannel = (UnknownInputChannel) current; + + InputChannel newChannel; + + if (partitionInfo.getProducerLocation() == PartitionInfo.PartitionLocation.REMOTE) { + newChannel = unknownChannel.toRemoteInputChannel(partitionInfo.getProducerAddress()); + } + else if (partitionInfo.getProducerLocation() == PartitionInfo.PartitionLocation.LOCAL) { + newChannel = unknownChannel.toLocalInputChannel(); + } + else { + throw new IllegalStateException("Tried to update unknown channel with unknown channel."); + } + + inputChannels.put(partitionId, newChannel); + + newChannel.requestIntermediateResultPartition(consumedSubpartitionIndex); + + for (TaskEvent event : pendingEvents) { + newChannel.sendTaskEvent(event); + } + + if (--numberOfUninitializedChannels == 0) { + pendingEvents.clear(); + } + } + } + } + + public void releaseAllResources() throws IOException { + synchronized (requestLock) { + if (!releasedResourcesFlag) { + try { + for (InputChannel inputChannel : inputChannels.values()) { + try { + inputChannel.releaseAllResources(); + } + catch (IOException e) { + LOG.warn("Error during release of channel resources: " + e.getMessage(), e); + } + } + + // The buffer pool can actually be destroyed immediately after the + // reader received all of the data from the input channels. + if (bufferPool != null) { + bufferPool.destroy(); + } + } + finally { + releasedResourcesFlag = true; + } + } + } + } + + @Override + public boolean isFinished() { + synchronized (requestLock) { + for (InputChannel inputChannel : inputChannels.values()) { + if (!inputChannel.isReleased()) { + return false; + } + } + } + + return true; + } + + @Override + public void requestPartitions() throws IOException { + if (!requestedPartitionsFlag) { + // Sanity check + if (numberOfInputChannels != inputChannels.size()) { + throw new IllegalStateException("Bug in input gate setup logic: mismatch between" + + "number of total input channels and the currently set number of input " + + "channels."); + } + + synchronized (requestLock) { + for (InputChannel inputChannel : inputChannels.values()) { + inputChannel.requestIntermediateResultPartition(consumedSubpartitionIndex); + } + } + + requestedPartitionsFlag = true; + } + } + + // ------------------------------------------------------------------------ + // Consume + // ------------------------------------------------------------------------ + + @Override + public BufferOrEvent getNextBufferOrEvent() throws IOException, InterruptedException { + + if (releasedResourcesFlag) { + throw new IllegalStateException("The input has already been consumed. This indicates misuse of the input gate."); + } + + requestPartitions(); + + final InputChannel currentChannel = inputChannelsWithData.take(); + + final Buffer buffer = currentChannel.getNextBuffer(); + + // Sanity check that notifications only happen when data is available + if (buffer == null) { + throw new IllegalStateException("Bug in input gate/channel logic: input gate got" + + "notified by channel about available data, but none was available."); + } + + if (buffer.isBuffer()) { + return new BufferOrEvent(buffer, currentChannel.getChannelIndex()); + } + else { + final AbstractEvent event = EventSerializer.fromBuffer(buffer, getClass().getClassLoader()); + + if (event.getClass() == EndOfPartitionEvent.class) { + currentChannel.releaseAllResources(); + } + + return new BufferOrEvent(event, currentChannel.getChannelIndex()); + } + } + + @Override + public void sendTaskEvent(TaskEvent event) throws IOException { + synchronized (requestLock) { + for (InputChannel inputChannel : inputChannels.values()) { + inputChannel.sendTaskEvent(event); + } + + if (numberOfUninitializedChannels > 0) { + pendingEvents.add(event); + } + } + } + + // ------------------------------------------------------------------------ + // Channel notifications + // ------------------------------------------------------------------------ + + @Override + public void registerListener(EventListener listener) { + registeredListeners.add(checkNotNull(listener)); + } + + public void onAvailableBuffer(InputChannel channel) { + inputChannelsWithData.add(channel); + + for (int i = 0; i < registeredListeners.size(); i++) { + registeredListeners.get(i).onEvent(this); + } + } + + // ------------------------------------------------------------------------ + + public static SingleInputGate create(NetworkEnvironment networkEnvironment, PartitionConsumerDeploymentDescriptor desc) { + // The consumed intermediate data set (all partitions are part of this data set) + final IntermediateDataSetID resultId = desc.getResultId(); + // The queue to request from each consumed partition + final int queueIndex = desc.getQueueIndex(); + // There is one input channel for each consumed partition + final PartitionInfo[] partitions = desc.getPartitions(); + final int numberOfInputChannels = partitions.length; + final SingleInputGate reader = new SingleInputGate(resultId, queueIndex, numberOfInputChannels); + // Create input channels + final InputChannel[] inputChannels = new InputChannel[numberOfInputChannels]; + int channelIndex = 0; + for (PartitionInfo partition : partitions) { + final ExecutionAttemptID producerExecutionId = partition.getProducerExecutionId(); + final IntermediateResultPartitionID partitionId = partition.getPartitionId(); + final PartitionInfo.PartitionLocation producerLocation = partition.getProducerLocation(); + switch (producerLocation) { + case LOCAL: + inputChannels[channelIndex] = new LocalInputChannel(reader, channelIndex, producerExecutionId, partitionId, networkEnvironment.getPartitionManager(), networkEnvironment.getTaskEventDispatcher()); + break; + case REMOTE: + final RemoteAddress producerAddress = checkNotNull(partition.getProducerAddress(), "Missing producer address for remote intermediate result partition."); + inputChannels[channelIndex] = new RemoteInputChannel(reader, channelIndex, producerExecutionId, partitionId, producerAddress, networkEnvironment.getConnectionManager()); + break; + case UNKNOWN: + inputChannels[channelIndex] = new UnknownInputChannel(reader, channelIndex, producerExecutionId, partitionId, networkEnvironment.getPartitionManager(), networkEnvironment.getTaskEventDispatcher(), networkEnvironment.getConnectionManager()); + break; + } + reader.setInputChannel(partitionId, inputChannels[channelIndex]); + channelIndex++; + } + return reader; + } +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGate.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGate.java new file mode 100644 index 0000000000000..4994f13029832 --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGate.java @@ -0,0 +1,205 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.io.network.partition.consumer; + +import com.google.common.collect.Maps; +import org.apache.flink.runtime.event.task.TaskEvent; +import org.apache.flink.runtime.util.event.EventListener; + +import java.io.IOException; +import java.util.List; +import java.util.Map; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.CopyOnWriteArrayList; +import java.util.concurrent.LinkedBlockingQueue; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkNotNull; + +/** + * Input gate wrapper to union the input from multiple input gates. + *

    + * Each input gate has input channels attached from which it reads data. At each input gate, the + * input channels have unique IDs from 0 (inclusive) to the number of input channels (exclusive). + * + *

    + * +---+---+      +---+---+---+
    + * | 0 | 1 |      | 0 | 1 | 2 |
    + * +--------------+--------------+
    + * | Input gate 0 | Input gate 1 |
    + * +--------------+--------------+
    + * 
    + * + * The union input gate maps these IDs from 0 to the *total* number of input channels across all + * unioned input gates, e.g. the channels of input gate 0 keep their original indexes and the + * channel indexes of input gate 1 are set off by 2 to 2--4. + * + *
    + * +---+---++---+---+---+
    + * | 0 | 1 || 2 | 3 | 4 |
    + * +--------------------+
    + * | Union input gate   |
    + * +--------------------+
    + * 
    + * + * It is possible to recursively union union input gates. + */ +public class UnionInputGate implements InputGate { + + /** The input gates to union. */ + private final InputGate[] inputGates; + + /** Data availability listener across all unioned input gates. */ + private final InputGateListener inputGateListener; + + /** The total number of input channels across all unioned input gates. */ + private final int totalNumberOfInputChannels; + + /** + * A mapping from input gate to (logical) channel index offset. Valid channel indexes go from 0 + * (inclusive) to the total number of input channels (exclusive). + */ + private final Map inputGateToIndexOffsetMap; + + /** Flag indicating whether partitions have been requested. */ + private boolean requestedPartitionsFlag; + + public UnionInputGate(InputGate... inputGates) { + this.inputGates = checkNotNull(inputGates); + checkArgument(inputGates.length > 1, "Union input gate should union at least two input gates."); + + this.inputGateToIndexOffsetMap = Maps.newHashMapWithExpectedSize(inputGates.length); + + int currentNumberOfInputChannels = 0; + + for (InputGate inputGate : inputGates) { + // The offset to use for buffer or event instances received from this input gate. + inputGateToIndexOffsetMap.put(checkNotNull(inputGate), currentNumberOfInputChannels); + + currentNumberOfInputChannels += inputGate.getNumberOfInputChannels(); + } + + this.totalNumberOfInputChannels = currentNumberOfInputChannels; + + this.inputGateListener = new InputGateListener(inputGates, this); + } + + /** + * Returns the total number of input channels across all unioned input gates. + */ + @Override + public int getNumberOfInputChannels() { + return totalNumberOfInputChannels; + } + + @Override + public boolean isFinished() { + for (InputGate inputGate : inputGates) { + if (!inputGate.isFinished()) { + return false; + } + } + + return true; + } + + @Override + public void requestPartitions() throws IOException { + if (!requestedPartitionsFlag) { + for (InputGate inputGate : inputGates) { + inputGate.requestPartitions(); + } + + requestedPartitionsFlag = true; + } + } + + @Override + public BufferOrEvent getNextBufferOrEvent() throws IOException, InterruptedException { + + // Make sure to request the partitions, if they have not been requested before. + requestPartitions(); + + final InputGate inputGate = inputGateListener.getNextInputGateToReadFrom(); + + final BufferOrEvent bufferOrEvent = inputGate.getNextBufferOrEvent(); + + // Set the channel index to identify the input channel (across all unioned input gates) + final int channelIndexOffset = inputGateToIndexOffsetMap.get(inputGate); + + bufferOrEvent.setChannelIndex(channelIndexOffset + bufferOrEvent.getChannelIndex()); + + return bufferOrEvent; + } + + @Override + public void sendTaskEvent(TaskEvent event) throws IOException { + for (InputGate inputGate : inputGates) { + inputGate.sendTaskEvent(event); + } + } + + @Override + public void registerListener(EventListener listener) { + // This method is called from the consuming task thread. + inputGateListener.registerListener(listener); + } + + /** + * Data availability listener at all unioned input gates. + *

    + * The listener registers itself at each input gate and is notified for *each incoming buffer* + * at one of the unioned input gates. + */ + private static class InputGateListener implements EventListener { + + private final UnionInputGate unionInputGate; + + private final BlockingQueue inputGatesWithData = new LinkedBlockingQueue(); + + private final List> registeredListeners = new CopyOnWriteArrayList>(); + + public InputGateListener(InputGate[] inputGates, UnionInputGate unionInputGate) { + for (InputGate inputGate : inputGates) { + inputGate.registerListener(this); + } + + this.unionInputGate = unionInputGate; + } + + @Override + public void onEvent(InputGate inputGate) { + // This method is called from the input channel thread, which can be either the same + // thread as the consuming task thread or a different one. + inputGatesWithData.add(inputGate); + + for (int i = 0; i < registeredListeners.size(); i++) { + registeredListeners.get(i).onEvent(unionInputGate); + } + } + + InputGate getNextInputGateToReadFrom() throws InterruptedException { + return inputGatesWithData.take(); + } + + public void registerListener(EventListener listener) { + registeredListeners.add(checkNotNull(listener)); + } + } +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/UnknownInputChannel.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/UnknownInputChannel.java index f1fae895a9319..f8e42bac794f4 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/UnknownInputChannel.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/UnknownInputChannel.java @@ -20,9 +20,12 @@ import org.apache.flink.runtime.event.task.TaskEvent; import org.apache.flink.runtime.executiongraph.ExecutionAttemptID; +import org.apache.flink.runtime.io.network.ConnectionManager; import org.apache.flink.runtime.io.network.RemoteAddress; +import org.apache.flink.runtime.io.network.TaskEventDispatcher; import org.apache.flink.runtime.io.network.api.reader.BufferReader; import org.apache.flink.runtime.io.network.buffer.Buffer; +import org.apache.flink.runtime.io.network.partition.IntermediateResultPartitionManager; import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID; import java.io.IOException; @@ -35,8 +38,25 @@ */ public class UnknownInputChannel extends InputChannel { - public UnknownInputChannel(int channelIndex, ExecutionAttemptID producerExecutionId, IntermediateResultPartitionID partitionId, BufferReader reader) { - super(channelIndex, producerExecutionId, partitionId, reader); + private final IntermediateResultPartitionManager partitionManager; + + private final TaskEventDispatcher taskEventDispatcher; + + private final ConnectionManager connectionManager; + + public UnknownInputChannel( + SingleInputGate gate, int channelIndex, + ExecutionAttemptID producerExecutionId, + IntermediateResultPartitionID partitionId, + IntermediateResultPartitionManager partitionManager, + TaskEventDispatcher taskEventDispatcher, + ConnectionManager connectionManager) { + + super(gate, channelIndex, producerExecutionId, partitionId); + + this.partitionManager = partitionManager; + this.taskEventDispatcher = taskEventDispatcher; + this.connectionManager = connectionManager; } @Override @@ -83,10 +103,10 @@ public String toString() { // ------------------------------------------------------------------------ public RemoteInputChannel toRemoteInputChannel(RemoteAddress producerAddress) { - return new RemoteInputChannel(channelIndex, producerExecutionId, partitionId, reader, checkNotNull(producerAddress)); + return new RemoteInputChannel(inputGate, channelIndex, producerExecutionId, partitionId, checkNotNull(producerAddress), connectionManager); } public LocalInputChannel toLocalInputChannel() { - return new LocalInputChannel(channelIndex, producerExecutionId, partitionId, reader); + return new LocalInputChannel(inputGate, channelIndex, producerExecutionId, partitionId, partitionManager, taskEventDispatcher); } } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/iterative/task/IterationSynchronizationSinkTask.java b/flink-runtime/src/main/java/org/apache/flink/runtime/iterative/task/IterationSynchronizationSinkTask.java index 6ce615d8eb346..b222cb5546d61 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/iterative/task/IterationSynchronizationSinkTask.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/iterative/task/IterationSynchronizationSinkTask.java @@ -73,7 +73,7 @@ public class IterationSynchronizationSinkTask extends AbstractInvokable implemen @Override public void registerInputOutput() { - this.headEventReader = new MutableRecordReader(getEnvironment().getReader(0)); + this.headEventReader = new MutableRecordReader(getEnvironment().getInputGate(0)); } @Override @@ -99,7 +99,7 @@ public void invoke() throws Exception { int numEventsTillEndOfSuperstep = taskConfig.getNumberOfEventsUntilInterruptInIterativeGate(0); eventHandler = new SyncEventHandler(numEventsTillEndOfSuperstep, aggregators, getEnvironment().getUserClassLoader()); - headEventReader.subscribeToTaskEvent(eventHandler, WorkerDoneEvent.class); + headEventReader.registerTaskEventListener(eventHandler, WorkerDoneEvent.class); IntegerRecord dummy = new IntegerRecord(); diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/operators/DataSinkTask.java b/flink-runtime/src/main/java/org/apache/flink/runtime/operators/DataSinkTask.java index efb6af4cd7a4f..a5df4140a50ed 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/operators/DataSinkTask.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/operators/DataSinkTask.java @@ -29,7 +29,7 @@ import org.apache.flink.runtime.execution.CancelTaskException; import org.apache.flink.runtime.io.network.api.reader.MutableReader; import org.apache.flink.runtime.io.network.api.reader.MutableRecordReader; -import org.apache.flink.runtime.io.network.api.reader.UnionBufferReader; +import org.apache.flink.runtime.io.network.partition.consumer.UnionInputGate; import org.apache.flink.runtime.jobgraph.tasks.AbstractInvokable; import org.apache.flink.runtime.operators.chaining.ExceptionInChainedStubException; import org.apache.flink.runtime.operators.sort.UnilateralSortMerger; @@ -345,11 +345,10 @@ private void initInputReaders() throws Exception { numGates += groupSize; if (groupSize == 1) { // non-union case - inputReader = new MutableRecordReader>(getEnvironment().getReader(0)); + inputReader = new MutableRecordReader>(getEnvironment().getInputGate(0)); } else if (groupSize > 1){ // union case - UnionBufferReader reader = new UnionBufferReader(getEnvironment().getAllReaders()); - inputReader = new MutableRecordReader(reader); + inputReader = new MutableRecordReader(new UnionInputGate(getEnvironment().getAllInputGates())); } else { throw new Exception("Illegal input group size in task configuration: " + groupSize); } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/operators/RegularPactTask.java b/flink-runtime/src/main/java/org/apache/flink/runtime/operators/RegularPactTask.java index 2f2331fa5fad5..38b71a14203be 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/operators/RegularPactTask.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/operators/RegularPactTask.java @@ -37,12 +37,12 @@ import org.apache.flink.runtime.execution.CancelTaskException; import org.apache.flink.runtime.execution.Environment; import org.apache.flink.runtime.io.disk.iomanager.IOManager; -import org.apache.flink.runtime.io.network.api.reader.BufferReader; import org.apache.flink.runtime.io.network.api.reader.MutableReader; import org.apache.flink.runtime.io.network.api.reader.MutableRecordReader; -import org.apache.flink.runtime.io.network.api.reader.UnionBufferReader; import org.apache.flink.runtime.io.network.api.writer.ChannelSelector; import org.apache.flink.runtime.io.network.api.writer.RecordWriter; +import org.apache.flink.runtime.io.network.partition.consumer.InputGate; +import org.apache.flink.runtime.io.network.partition.consumer.UnionInputGate; import org.apache.flink.runtime.jobgraph.tasks.AbstractInvokable; import org.apache.flink.runtime.memorymanager.MemoryManager; import org.apache.flink.runtime.messages.JobManagerMessages; @@ -719,15 +719,14 @@ protected void initInputReaders() throws Exception { if (groupSize == 1) { // non-union case - inputReaders[i] = new MutableRecordReader(getEnvironment().getReader(currentReaderOffset)); + inputReaders[i] = new MutableRecordReader(getEnvironment().getInputGate(currentReaderOffset)); } else if (groupSize > 1){ // union case - BufferReader[] readers = new BufferReader[groupSize]; + InputGate[] readers = new InputGate[groupSize]; for (int j = 0; j < groupSize; ++j) { - readers[j] = getEnvironment().getReader(currentReaderOffset + j); + readers[j] = getEnvironment().getInputGate(currentReaderOffset + j); } - UnionBufferReader reader = new UnionBufferReader(readers); - inputReaders[i] = new MutableRecordReader(reader); + inputReaders[i] = new MutableRecordReader(new UnionInputGate(readers)); } else { throw new Exception("Illegal input group size in task configuration: " + groupSize); } @@ -759,15 +758,14 @@ protected void initBroadcastInputReaders() throws Exception { final int groupSize = this.config.getBroadcastGroupSize(i); if (groupSize == 1) { // non-union case - broadcastInputReaders[i] = new MutableRecordReader(getEnvironment().getReader(currentReaderOffset)); + broadcastInputReaders[i] = new MutableRecordReader(getEnvironment().getInputGate(currentReaderOffset)); } else if (groupSize > 1){ // union case - BufferReader[] readers = new BufferReader[groupSize]; + InputGate[] readers = new InputGate[groupSize]; for (int j = 0; j < groupSize; ++j) { - readers[j] = getEnvironment().getReader(currentReaderOffset + j); + readers[j] = getEnvironment().getInputGate(currentReaderOffset + j); } - UnionBufferReader reader = new UnionBufferReader(readers); - broadcastInputReaders[i] = new MutableRecordReader(reader); + broadcastInputReaders[i] = new MutableRecordReader(new UnionInputGate(readers)); } else { throw new Exception("Illegal input group size in task configuration: " + groupSize); } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/Task.java b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/Task.java index 715515ed69d90..d628da11a85d3 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/Task.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/Task.java @@ -19,18 +19,18 @@ package org.apache.flink.runtime.taskmanager; import akka.actor.ActorRef; -import org.apache.flink.runtime.messages.JobManagerMessages; import org.apache.flink.configuration.Configuration; import org.apache.flink.runtime.execution.ExecutionState; import org.apache.flink.runtime.execution.RuntimeEnvironment; import org.apache.flink.runtime.executiongraph.ExecutionAttemptID; -import org.apache.flink.runtime.io.network.api.reader.BufferReader; import org.apache.flink.runtime.io.network.api.writer.BufferWriter; import org.apache.flink.runtime.io.network.partition.IntermediateResultPartition; +import org.apache.flink.runtime.io.network.partition.consumer.SingleInputGate; import org.apache.flink.runtime.jobgraph.JobID; import org.apache.flink.runtime.jobgraph.JobVertexID; import org.apache.flink.runtime.memorymanager.MemoryManager; import org.apache.flink.runtime.messages.ExecutionGraphMessages; +import org.apache.flink.runtime.messages.JobManagerMessages; import org.apache.flink.runtime.profiling.TaskManagerProfiler; import org.apache.flink.util.ExceptionUtils; import org.slf4j.Logger; @@ -363,8 +363,8 @@ public void unregisterProfiler(TaskManagerProfiler taskManagerProfiler) { // Intermediate result partitions // ------------------------------------------------------------------------ - public BufferReader[] getReaders() { - return environment != null ? environment.getAllReaders() : null; + public SingleInputGate[] getInputGates() { + return environment != null ? environment.getAllInputGates() : null; } public BufferWriter[] getWriters() { diff --git a/flink-runtime/src/main/scala/org/apache/flink/runtime/taskmanager/TaskManager.scala b/flink-runtime/src/main/scala/org/apache/flink/runtime/taskmanager/TaskManager.scala index 9d157239fe259..898abd9575899 100644 --- a/flink-runtime/src/main/scala/org/apache/flink/runtime/taskmanager/TaskManager.scala +++ b/flink-runtime/src/main/scala/org/apache/flink/runtime/taskmanager/TaskManager.scala @@ -532,7 +532,7 @@ import scala.collection.JavaConverters._ case Some(task) => val errors = partitionInfos flatMap { case (resultID, partitionInfo) => - Option(task.getEnvironment.getReaderById(resultID)) match { + Option(task.getEnvironment.getInputGateById(resultID)) match { case Some(reader) => Future { try { diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/reader/AbstractReaderTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/reader/AbstractReaderTest.java new file mode 100644 index 0000000000000..0d8183de84f03 --- /dev/null +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/reader/AbstractReaderTest.java @@ -0,0 +1,186 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.io.network.api.reader; + +import org.apache.flink.core.memory.DataInputView; +import org.apache.flink.core.memory.DataOutputView; +import org.apache.flink.runtime.event.task.TaskEvent; +import org.apache.flink.runtime.io.network.api.EndOfPartitionEvent; +import org.apache.flink.runtime.io.network.api.EndOfSuperstepEvent; +import org.apache.flink.runtime.io.network.partition.consumer.InputGate; +import org.apache.flink.runtime.util.event.EventListener; +import org.junit.Test; +import org.mockito.Matchers; + +import java.io.IOException; + +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +/** + * Tests for the event handling behaviour. + */ +public class AbstractReaderTest { + + @Test + public void testTaskEvent() throws Exception { + final AbstractReader reader = new MockReader(createInputGate(1)); + + final EventListener listener1 = mock(EventListener.class); + final EventListener listener2 = mock(EventListener.class); + final EventListener listener3 = mock(EventListener.class); + + reader.registerTaskEventListener(listener1, TestTaskEvent1.class); + reader.registerTaskEventListener(listener2, TestTaskEvent2.class); + reader.registerTaskEventListener(listener3, TaskEvent.class); + + reader.handleEvent(new TestTaskEvent1()); // for listener1 only + reader.handleEvent(new TestTaskEvent2()); // for listener2 only + + verify(listener1, times(1)).onEvent(Matchers.any(TaskEvent.class)); + verify(listener2, times(1)).onEvent(Matchers.any(TaskEvent.class)); + verify(listener3, times(0)).onEvent(Matchers.any(TaskEvent.class)); + } + + @Test + public void testEndOfPartitionEvent() throws Exception { + final AbstractReader reader = new MockReader(createInputGate(1)); + + assertTrue(reader.handleEvent(new EndOfPartitionEvent())); + } + + /** + * Ensure that all end of superstep event related methods throw an Exception when used with a + * non-iterative reader. + */ + @Test + public void testExceptionsNonIterativeReader() throws Exception { + + final AbstractReader reader = new MockReader(createInputGate(4)); + + // Non-iterative reader cannot reach end of superstep + assertFalse(reader.hasReachedEndOfSuperstep()); + + try { + reader.startNextSuperstep(); + + fail("Did not throw expected exception when starting next superstep with non-iterative reader."); + } + catch (Throwable t) { + // All good, expected exception. + } + + try { + reader.handleEvent(new EndOfSuperstepEvent()); + + fail("Did not throw expected exception when handling end of superstep event with non-iterative reader."); + } + catch (Throwable t) { + // All good, expected exception. + } + } + + @Test + public void testEndOfSuperstepEventLogic() throws IOException { + + final int numberOfInputChannels = 4; + final AbstractReader reader = new MockReader(createInputGate(numberOfInputChannels)); + + reader.setIterativeReader(); + + try { + // The first superstep does not need not to be explicitly started + reader.startNextSuperstep(); + + fail("Did not throw expected exception when starting next superstep before receiving all end of superstep events."); + } + catch (Throwable t) { + // All good, expected exception. + } + + EndOfSuperstepEvent eos = new EndOfSuperstepEvent(); + + // One end of superstep event for each input channel. The superstep finishes with the last + // received event. + for (int i = 0; i < numberOfInputChannels - 1; i++) { + assertFalse(reader.handleEvent(eos)); + assertFalse(reader.hasReachedEndOfSuperstep()); + } + + assertTrue(reader.handleEvent(eos)); + assertTrue(reader.hasReachedEndOfSuperstep()); + + try { + // Verify exception, when receiving too many end of superstep events. + reader.handleEvent(eos); + + fail("Did not throw expected exception when receiving too many end of superstep events."); + } + catch (Throwable t) { + // All good, expected exception. + } + + // Start next superstep. + reader.startNextSuperstep(); + assertFalse(reader.hasReachedEndOfSuperstep()); + } + + private InputGate createInputGate(int numberOfInputChannels) { + final InputGate inputGate = mock(InputGate.class); + when(inputGate.getNumberOfInputChannels()).thenReturn(numberOfInputChannels); + + return inputGate; + } + + // ------------------------------------------------------------------------ + + private static class TestTaskEvent1 extends TaskEvent { + + @Override + public void write(DataOutputView out) throws IOException { + } + + @Override + public void read(DataInputView in) throws IOException { + } + } + + private static class TestTaskEvent2 extends TaskEvent { + + @Override + public void write(DataOutputView out) throws IOException { + } + + @Override + public void read(DataInputView in) throws IOException { + } + } + + private static class MockReader extends AbstractReader { + + protected MockReader(InputGate inputGate) { + super(inputGate); + } + } +} diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/reader/BufferReaderTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/reader/BufferReaderTest.java index f7f87af69d67b..f11c0db03accc 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/reader/BufferReaderTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/reader/BufferReaderTest.java @@ -18,32 +18,14 @@ package org.apache.flink.runtime.io.network.api.reader; -import com.google.common.base.Optional; -import org.apache.flink.core.memory.MemorySegment; -import org.apache.flink.runtime.deployment.PartitionInfo; import org.apache.flink.runtime.event.task.TaskEvent; -import org.apache.flink.runtime.execution.RuntimeEnvironment; -import org.apache.flink.runtime.executiongraph.ExecutionAttemptID; -import org.apache.flink.runtime.io.network.NetworkEnvironment; -import org.apache.flink.runtime.io.network.TaskEventDispatcher; -import org.apache.flink.runtime.io.network.api.reader.MockBufferReader.TestTaskEvent; import org.apache.flink.runtime.io.network.buffer.Buffer; -import org.apache.flink.runtime.io.network.buffer.BufferPool; -import org.apache.flink.runtime.io.network.buffer.BufferProvider; -import org.apache.flink.runtime.io.network.buffer.BufferRecycler; -import org.apache.flink.runtime.io.network.partition.IntermediateResultPartitionManager; -import org.apache.flink.runtime.io.network.partition.IntermediateResultPartitionProvider; -import org.apache.flink.runtime.io.network.partition.consumer.InputChannel; -import org.apache.flink.runtime.io.network.partition.consumer.LocalInputChannel; -import org.apache.flink.runtime.io.network.partition.consumer.UnknownInputChannel; -import org.apache.flink.runtime.io.network.partition.queue.IntermediateResultPartitionQueueIterator; -import org.apache.flink.runtime.jobgraph.IntermediateDataSetID; -import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID; +import org.apache.flink.runtime.io.network.util.MockSingleInputGate; +import org.apache.flink.runtime.io.network.util.TestTaskEvent; import org.apache.flink.runtime.taskmanager.Task; import org.apache.flink.runtime.util.event.EventListener; import org.junit.Test; import org.junit.runner.RunWith; -import org.mockito.Matchers; import org.powermock.core.classloader.annotations.PrepareForTest; import org.powermock.modules.junit4.PowerMockRunner; @@ -51,180 +33,80 @@ import static org.junit.Assert.assertEquals; import static org.mockito.Matchers.any; -import static org.mockito.Matchers.anyInt; -import static org.mockito.Matchers.anyObject; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; @RunWith(PowerMockRunner.class) @PrepareForTest(Task.class) public class BufferReaderTest { - @Test - public void testTaskEventNotification() throws IOException, InterruptedException { - final MockBufferReader mockReader = new MockBufferReader() - .readEvent().finish(); - - // Task event listener to be notified... - final EventListener listener = mock(EventListener.class); - - final BufferReader reader = mockReader.getMock(); - - // Task event listener to be notified... - reader.subscribeToTaskEvent(listener, TestTaskEvent.class); - - consumeAndVerify(reader, 0); - - verifyListenerCalled(listener, 1); - } - @Test public void testGetNextBufferOrEvent() throws IOException, InterruptedException { - final MockBufferReader mockReader = new MockBufferReader() - .readBuffer().readBuffer().readEvent().readBuffer().readBuffer().readEvent().readBuffer().finish(); + final MockSingleInputGate inputGate = new MockSingleInputGate(1) + .readBuffer().readBuffer().readEvent() + .readBuffer().readBuffer().readEvent() + .readBuffer().readEndOfPartitionEvent(); + + final BufferReader reader = new BufferReader(inputGate.getInputGate()); // Task event listener to be notified... final EventListener listener = mock(EventListener.class); + reader.registerTaskEventListener(listener, TestTaskEvent.class); - final BufferReader reader = mockReader.getMock(); - - reader.subscribeToTaskEvent(listener, TestTaskEvent.class); - - // Consume the reader - consumeAndVerify(reader, 5); + int numReadBuffers = 0; + while ((reader.getNextBuffer()) != null) { + numReadBuffers++; + } - verifyListenerCalled(listener, 2); + assertEquals(5, numReadBuffers); + verify(listener, times(2)).onEvent(any(TaskEvent.class)); } @Test public void testIterativeGetNextBufferOrEvent() throws IOException, InterruptedException { - final MockBufferReader mockReader = new MockBufferReader() - .readBuffer().readBuffer().readEvent().readBuffer().readBuffer().readEvent().readBuffer().finishSuperstep() - .readBuffer().readBuffer().readEvent().readBuffer().readBuffer().readEvent().readBuffer().finish(); - - // Task event listener to be notified... - final EventListener listener = mock(EventListener.class); + final MockSingleInputGate inputGate = new MockSingleInputGate(1) + .readBuffer().readBuffer().readEvent() + .readBuffer().readBuffer().readEvent() + .readBuffer().readEndOfSuperstepEvent() + .readBuffer().readBuffer().readEvent() + .readBuffer().readBuffer().readEvent() + .readBuffer().readEndOfPartitionEvent(); - final BufferReader reader = mockReader.getMock(); + final BufferReader reader = new BufferReader(inputGate.getInputGate()); // Set reader iterative reader.setIterativeReader(); // Task event listener to be notified... - reader.subscribeToTaskEvent(listener, TestTaskEvent.class); - - // Consume the reader - consumeAndVerify(reader, 10, 1); - - verifyListenerCalled(listener, 4); - } - - @Test(expected = IOException.class) - public void testExceptionEndOfSuperstepEventWithNonIterativeReader() throws IOException, InterruptedException { - - final MockBufferReader mockReader = new MockBufferReader().finishSuperstep(); - - final BufferReader reader = mockReader.getMock(); - - // Should throw Exception, because it's a non-iterative reader - reader.getNextBufferBlocking(); - } - - @Test - public void testBackwardsEventWithUninitializedChannel() throws Exception { - // Setup environment - final NetworkEnvironment networkEnvironment = mock(NetworkEnvironment.class); - final TaskEventDispatcher taskEventDispatcher = mock(TaskEventDispatcher.class); - when(taskEventDispatcher.publish(any(ExecutionAttemptID.class), any(IntermediateResultPartitionID.class), any(TaskEvent.class))).thenReturn(true); - - final IntermediateResultPartitionManager partitionManager = mock(IntermediateResultPartitionManager.class); - - final IntermediateResultPartitionQueueIterator iterator = mock(IntermediateResultPartitionQueueIterator.class); - when(iterator.getNextBuffer()).thenReturn(new Buffer(new MemorySegment(new byte[1024]), mock(BufferRecycler.class))); - - final BufferPool bufferPool = mock(BufferPool.class); - when(bufferPool.getNumberOfRequiredMemorySegments()).thenReturn(2); - - when(networkEnvironment.getTaskEventDispatcher()).thenReturn(taskEventDispatcher); - when(networkEnvironment.getPartitionManager()).thenReturn(partitionManager); - - when(partitionManager.getIntermediateResultPartitionIterator(any(ExecutionAttemptID.class), any(IntermediateResultPartitionID.class), anyInt(), any(Optional.class))).thenReturn(iterator); - - // Setup reader with one local and one unknown input channel - final IntermediateDataSetID resultId = new IntermediateDataSetID(); - final BufferReader reader = new BufferReader(mock(RuntimeEnvironment.class), networkEnvironment, resultId, 2, 0); - reader.setBufferPool(bufferPool); - - ExecutionAttemptID localProducer = new ExecutionAttemptID(); - IntermediateResultPartitionID localPartitionId = new IntermediateResultPartitionID(); - InputChannel local = new LocalInputChannel(0, localProducer, localPartitionId, reader); - - ExecutionAttemptID unknownProducer = new ExecutionAttemptID(); - IntermediateResultPartitionID unknownPartitionId = new IntermediateResultPartitionID(); - InputChannel unknown = new UnknownInputChannel(2, unknownProducer, unknownPartitionId, reader); - - reader.setInputChannel(localPartitionId, local); - reader.setInputChannel(unknownPartitionId, unknown); - - reader.requestPartitionsOnce(); - - // Just request the one local channel - verify(partitionManager, times(1)).getIntermediateResultPartitionIterator(any(ExecutionAttemptID.class), any(IntermediateResultPartitionID.class), anyInt(), any(Optional.class)); - - // Send event backwards and initialize unknown channel afterwards - final TaskEvent event = new TestTaskEvent(); - reader.sendTaskEvent(event); - - // Only the local channel can send out the record - verify(taskEventDispatcher, times(1)).publish(any(ExecutionAttemptID.class), any(IntermediateResultPartitionID.class), any(TaskEvent.class)); - - // After the update, the pending event should be send to local channel - reader.updateInputChannel(new PartitionInfo(unknownPartitionId, unknownProducer, PartitionInfo.PartitionLocation.LOCAL, null)); - - verify(partitionManager, times(2)).getIntermediateResultPartitionIterator(any(ExecutionAttemptID.class), any(IntermediateResultPartitionID.class), anyInt(), any(Optional.class)); - verify(taskEventDispatcher, times(2)).publish(any(ExecutionAttemptID.class), any(IntermediateResultPartitionID.class), any(TaskEvent.class)); - } - - // ------------------------------------------------------------------------ - - static void verifyListenerCalled(EventListener mockListener, int expectedNumCalls) { - verify(mockListener, times(expectedNumCalls)).onEvent(any(TestTaskEvent.class)); - } - - static void consumeAndVerify(BufferReaderBase reader, int expectedNumReadBuffers) throws IOException, InterruptedException { - consumeAndVerify(reader, expectedNumReadBuffers, 0); - } + final EventListener listener = mock(EventListener.class); + // Task event listener to be notified... + reader.registerTaskEventListener(listener, TestTaskEvent.class); - static void consumeAndVerify(BufferReaderBase reader, int expectedNumReadBuffers, int expectedNumReadIterations) throws IOException, InterruptedException { int numReadBuffers = 0; - int numIterations = 0; + int numEndOfSuperstepEvents = 0; while (true) { - Buffer buffer; - while ((buffer = reader.getNextBufferBlocking()) != null) { - buffer.recycle(); + Buffer buffer = reader.getNextBuffer(); + if (buffer != null) { numReadBuffers++; } - - if (reader.isFinished()) { - break; - } else if (reader.hasReachedEndOfSuperstep()) { reader.startNextSuperstep(); - numIterations++; + numEndOfSuperstepEvents++; } - else { - continue; + else if (reader.isFinished()) { + break; } } - assertEquals(expectedNumReadBuffers, numReadBuffers); - assertEquals(expectedNumReadIterations, numIterations); + assertEquals(10, numReadBuffers); + assertEquals(1, numEndOfSuperstepEvents); + + verify(listener, times(4)).onEvent(any(TaskEvent.class)); } } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/reader/MockIteratorBufferReader.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/reader/IteratorWrappingMockSingleInputGate.java similarity index 69% rename from flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/reader/MockIteratorBufferReader.java rename to flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/reader/IteratorWrappingMockSingleInputGate.java index e6e31875becce..614fb8d56d94d 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/reader/MockIteratorBufferReader.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/reader/IteratorWrappingMockSingleInputGate.java @@ -26,6 +26,9 @@ import org.apache.flink.runtime.io.network.api.serialization.SpanningRecordSerializer; import org.apache.flink.runtime.io.network.buffer.Buffer; import org.apache.flink.runtime.io.network.buffer.BufferRecycler; +import org.apache.flink.runtime.io.network.util.MockInputChannel; +import org.apache.flink.runtime.io.network.util.MockSingleInputGate; +import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID; import org.apache.flink.util.InstantiationUtil; import org.apache.flink.util.MutableObjectIterator; import org.mockito.invocation.InvocationOnMock; @@ -33,11 +36,12 @@ import java.io.IOException; -import static com.google.common.base.Preconditions.checkState; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; -public class MockIteratorBufferReader extends MockBufferReader { +public class IteratorWrappingMockSingleInputGate extends MockSingleInputGate { + + private final MockInputChannel inputChannel = new MockInputChannel(inputGate, 0); private final int bufferSize; @@ -47,24 +51,16 @@ public class MockIteratorBufferReader extends Mock private final T reuse; - public MockIteratorBufferReader(int bufferSize, Class recordType) throws IOException { - this.bufferSize = bufferSize; - - this.reuse = InstantiationUtil.instantiate(recordType); - } + public IteratorWrappingMockSingleInputGate(int bufferSize, Class recordType, MutableObjectIterator iterator) throws IOException, InterruptedException { + super(1, false); - public MockIteratorBufferReader(int bufferSize, Class recordType, MutableObjectIterator iterator) throws IOException { this.bufferSize = bufferSize; - this.reuse = InstantiationUtil.instantiate(recordType); wrapIterator(iterator); } - public MockIteratorBufferReader wrapIterator(MutableObjectIterator iterator) throws IOException { - checkState(inputIterator == null, "Iterator has already been set."); - checkState(stubbing == null, "There is already an ongoing stubbing from the MockBufferReader, which can't be mixed with an Iterator."); - + private IteratorWrappingMockSingleInputGate wrapIterator(MutableObjectIterator iterator) throws IOException, InterruptedException { inputIterator = iterator; serializer = new SpanningRecordSerializer(); @@ -78,29 +74,29 @@ public Buffer answer(InvocationOnMock invocationOnMock) throws Throwable { serializer.setNextBuffer(buffer); serializer.addRecord(reuse); - reader.onAvailableInputChannel(inputChannel); + inputGate.onAvailableBuffer(inputChannel.getInputChannel()); // Call getCurrentBuffer to ensure size is set return serializer.getCurrentBuffer(); } else { - // Return true after finishing - when(inputChannel.isReleased()).thenReturn(true); + + when(inputChannel.getInputChannel().isReleased()).thenReturn(true); return EventSerializer.toBuffer(EndOfPartitionEvent.INSTANCE); } } }; - stubbing = when(inputChannel.getNextBuffer()).thenAnswer(answer); + when(inputChannel.getInputChannel().getNextBuffer()).thenAnswer(answer); + + inputGate.setInputChannel(new IntermediateResultPartitionID(), inputChannel.getInputChannel()); return this; } - public MockIteratorBufferReader read() { - checkState(inputIterator != null && serializer != null, "Iterator/serializer has not been set. Call wrapIterator() first."); - - reader.onAvailableInputChannel(inputChannel); + public IteratorWrappingMockSingleInputGate read() { + inputGate.onAvailableBuffer(inputChannel.getInputChannel()); return this; } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/reader/UnionBufferReaderTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/reader/UnionBufferReaderTest.java deleted file mode 100644 index a10555747fce4..0000000000000 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/reader/UnionBufferReaderTest.java +++ /dev/null @@ -1,139 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.runtime.io.network.api.reader; - -import org.apache.flink.runtime.event.task.TaskEvent; -import org.apache.flink.runtime.taskmanager.Task; -import org.apache.flink.runtime.util.event.EventListener; -import org.junit.Test; -import org.junit.runner.RunWith; -import org.powermock.core.classloader.annotations.PrepareForTest; -import org.powermock.modules.junit4.PowerMockRunner; - -import java.io.IOException; - -import static org.apache.flink.runtime.io.network.api.reader.BufferReaderTest.consumeAndVerify; -import static org.apache.flink.runtime.io.network.api.reader.BufferReaderTest.verifyListenerCalled; -import static org.apache.flink.runtime.io.network.api.reader.MockBufferReader.TestTaskEvent; -import static org.mockito.Mockito.mock; - -@RunWith(PowerMockRunner.class) -@PrepareForTest(Task.class) -public class UnionBufferReaderTest { - - @Test - public void testTaskEventNotifications() throws IOException, InterruptedException { - final MockBufferReader reader1 = new MockBufferReader(); - final MockBufferReader reader2 = new MockBufferReader(); - - final UnionBufferReader unionReader = new UnionBufferReader(reader1.getMock(), reader2.getMock()); - - reader1.readEvent().finish(); - reader2.readEvent().finish(); - - // Task event listener to be notified... - final EventListener listener = mock(EventListener.class); - - unionReader.subscribeToTaskEvent(listener, TestTaskEvent.class); - - consumeAndVerify(unionReader, 0); - - verifyListenerCalled(listener, 2); - } - - @Test - public void testGetNextBufferOrEvent() throws IOException, InterruptedException { - final MockBufferReader reader1 = new MockBufferReader(); - final MockBufferReader reader2 = new MockBufferReader(); - - final UnionBufferReader unionReader = new UnionBufferReader(reader1.getMock(), reader2.getMock()); - - reader1.readBuffer().readBuffer().readEvent().readBuffer().readBuffer().readEvent().readBuffer().finish(); - - reader2.readBuffer().readBuffer().readEvent().readBuffer().readBuffer().readEvent().readBuffer().finish(); - - // Task event listener to be notified... - final EventListener listener = mock(EventListener.class); - - unionReader.subscribeToTaskEvent(listener, TestTaskEvent.class); - - // Consume the reader - consumeAndVerify(unionReader, 10); - - verifyListenerCalled(listener, 4); - } - - @Test - public void testIterativeGetNextBufferOrEvent() throws IOException, InterruptedException { - final MockBufferReader reader1 = new MockBufferReader(); - final MockBufferReader reader2 = new MockBufferReader(); - - final UnionBufferReader unionReader = new UnionBufferReader(reader1.getMock(), reader2.getMock()); - - unionReader.setIterativeReader(); - - reader1.readBuffer().readBuffer().readEvent() - .readBuffer().readBuffer().readEvent() - .readBuffer().finishSuperstep().readBuffer().readBuffer() - .readEvent().readBuffer().readBuffer() - .readEvent().readBuffer().finish(); - - reader2.readBuffer().readBuffer().readEvent() - .readBuffer().readBuffer().readEvent() - .readBuffer().finishSuperstep().readBuffer().readBuffer() - .readEvent().readBuffer().readBuffer() - .readEvent().readBuffer().finish(); - - // Task event listener to be notified... - final EventListener listener = mock(EventListener.class); - unionReader.subscribeToTaskEvent(listener, TestTaskEvent.class); - - // Consume the reader - consumeAndVerify(unionReader, 20, 1); - - verifyListenerCalled(listener, 8); - } - - @Test - public void testGetNextBufferUnionOfUnionReader() throws Exception { - final MockBufferReader reader1 = new MockBufferReader(); - final MockBufferReader reader2 = new MockBufferReader(); - - final UnionBufferReader unionReader = new UnionBufferReader(reader1.getMock(), reader2.getMock()); - - final MockBufferReader reader3 = new MockBufferReader(); - - final UnionBufferReader unionUnionReader = new UnionBufferReader(unionReader, reader3.getMock()); - - reader1.readBuffer().readBuffer().readBuffer().readEvent().readEvent().readBuffer().finish(); - - reader2.readEvent().readBuffer().readBuffer().readEvent().readBuffer().finish(); - - reader3.readBuffer().readBuffer().readEvent().readEvent().finish(); - - // Task event listener to be notified... - final EventListener listener = mock(EventListener.class); - unionUnionReader.subscribeToTaskEvent(listener, TestTaskEvent.class); - - // Consume the reader - consumeAndVerify(unionUnionReader, 9); - - verifyListenerCalled(listener, 6); - } -} diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/serialization/EventSerializerTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/serialization/EventSerializerTest.java index 4772a0ed7a830..5a20a4b9d80d7 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/serialization/EventSerializerTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/serialization/EventSerializerTest.java @@ -18,13 +18,10 @@ package org.apache.flink.runtime.io.network.api.serialization; -import org.apache.flink.core.memory.DataInputView; -import org.apache.flink.core.memory.DataOutputView; import org.apache.flink.runtime.event.task.AbstractEvent; -import org.apache.flink.runtime.event.task.TaskEvent; +import org.apache.flink.runtime.io.network.util.TestTaskEvent; import org.junit.Test; -import java.io.IOException; import java.nio.ByteBuffer; import static org.junit.Assert.assertEquals; @@ -42,42 +39,4 @@ public void testSerializeDeserializeEvent() { assertEquals(expected, actual); } - - public static class TestTaskEvent extends TaskEvent { - - private double val0; - - private long val1; - - public TestTaskEvent() { - } - - public TestTaskEvent(double val0, long val1) { - this.val0 = val0; - this.val1 = val1; - } - - @Override - public void write(DataOutputView out) throws IOException { - out.writeDouble(val0); - out.writeLong(val1); - } - - @Override - public void read(DataInputView in) throws IOException { - val0 = in.readDouble(); - val1 = in.readLong(); - } - - @Override - public boolean equals(Object obj) { - if (obj instanceof TestTaskEvent) { - TestTaskEvent other = (TestTaskEvent) obj; - - return val0 == other.val0 && val1 == other.val1; - } - - return false; - } - } } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateTest.java new file mode 100644 index 0000000000000..e5c87e9aa3631 --- /dev/null +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateTest.java @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.io.network.partition.consumer; + +import com.google.common.base.Optional; +import org.apache.flink.core.memory.MemorySegment; +import org.apache.flink.runtime.deployment.PartitionInfo; +import org.apache.flink.runtime.event.task.TaskEvent; +import org.apache.flink.runtime.executiongraph.ExecutionAttemptID; +import org.apache.flink.runtime.io.network.ConnectionManager; +import org.apache.flink.runtime.io.network.TaskEventDispatcher; +import org.apache.flink.runtime.io.network.buffer.Buffer; +import org.apache.flink.runtime.io.network.buffer.BufferPool; +import org.apache.flink.runtime.io.network.buffer.BufferRecycler; +import org.apache.flink.runtime.io.network.partition.IntermediateResultPartitionManager; +import org.apache.flink.runtime.io.network.partition.queue.IntermediateResultPartitionQueueIterator; +import org.apache.flink.runtime.io.network.util.TestTaskEvent; +import org.apache.flink.runtime.jobgraph.IntermediateDataSetID; +import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID; +import org.junit.Test; + +import static org.mockito.Matchers.any; +import static org.mockito.Matchers.anyInt; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +public class SingleInputGateTest { + + @Test + public void testBackwardsEventWithUninitializedChannel() throws Exception { + // Setup environment + final TaskEventDispatcher taskEventDispatcher = mock(TaskEventDispatcher.class); + when(taskEventDispatcher.publish(any(ExecutionAttemptID.class), any(IntermediateResultPartitionID.class), any(TaskEvent.class))).thenReturn(true); + + final IntermediateResultPartitionQueueIterator iterator = mock(IntermediateResultPartitionQueueIterator.class); + when(iterator.getNextBuffer()).thenReturn(new Buffer(new MemorySegment(new byte[1024]), mock(BufferRecycler.class))); + + final IntermediateResultPartitionManager partitionManager = mock(IntermediateResultPartitionManager.class); + when(partitionManager.getIntermediateResultPartitionIterator(any(ExecutionAttemptID.class), any(IntermediateResultPartitionID.class), anyInt(), any(Optional.class))).thenReturn(iterator); + + // Setup reader with one local and one unknown input channel + final IntermediateDataSetID resultId = new IntermediateDataSetID(); + + final SingleInputGate inputGate = new SingleInputGate(resultId, 0, 2); + final BufferPool bufferPool = mock(BufferPool.class); + when(bufferPool.getNumberOfRequiredMemorySegments()).thenReturn(2); + + inputGate.setBufferPool(bufferPool); + + // Local + ExecutionAttemptID localProducer = new ExecutionAttemptID(); + IntermediateResultPartitionID localPartitionId = new IntermediateResultPartitionID(); + + InputChannel local = new LocalInputChannel(inputGate, 0, localProducer, localPartitionId, partitionManager, taskEventDispatcher); + + // Unknown + ExecutionAttemptID unknownProducer = new ExecutionAttemptID(); + IntermediateResultPartitionID unknownPartitionId = new IntermediateResultPartitionID(); + + InputChannel unknown = new UnknownInputChannel(inputGate, 1, unknownProducer, unknownPartitionId, partitionManager, taskEventDispatcher, mock(ConnectionManager.class)); + + // Set channels + inputGate.setInputChannel(localPartitionId, local); + inputGate.setInputChannel(unknownPartitionId, unknown); + + // Request partitions + inputGate.requestPartitions(); + + // Only the local channel can request + verify(partitionManager, times(1)).getIntermediateResultPartitionIterator(any(ExecutionAttemptID.class), any(IntermediateResultPartitionID.class), anyInt(), any(Optional.class)); + + // Send event backwards and initialize unknown channel afterwards + final TaskEvent event = new TestTaskEvent(); + inputGate.sendTaskEvent(event); + + // Only the local channel can send out the event + verify(taskEventDispatcher, times(1)).publish(any(ExecutionAttemptID.class), any(IntermediateResultPartitionID.class), any(TaskEvent.class)); + + // After the update, the pending event should be send to local channel + inputGate.updateInputChannel(new PartitionInfo(unknownPartitionId, unknownProducer, PartitionInfo.PartitionLocation.LOCAL, null)); + + verify(partitionManager, times(2)).getIntermediateResultPartitionIterator(any(ExecutionAttemptID.class), any(IntermediateResultPartitionID.class), anyInt(), any(Optional.class)); + verify(taskEventDispatcher, times(2)).publish(any(ExecutionAttemptID.class), any(IntermediateResultPartitionID.class), any(TaskEvent.class)); + } +} diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGateTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGateTest.java new file mode 100644 index 0000000000000..c1db44d64f0de --- /dev/null +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGateTest.java @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.io.network.partition.consumer; + +import org.apache.flink.runtime.io.network.util.MockInputChannel; +import org.apache.flink.runtime.jobgraph.IntermediateDataSetID; +import org.junit.Test; + +import static org.junit.Assert.assertEquals; + +public class UnionInputGateTest { + + @Test + public void testChannelMapping() throws Exception { + + final SingleInputGate ig1 = new SingleInputGate(new IntermediateDataSetID(), 0, 3); + final SingleInputGate ig2 = new SingleInputGate(new IntermediateDataSetID(), 0, 5); + + final UnionInputGate union = new UnionInputGate(new SingleInputGate[]{ig1, ig2}); + + assertEquals(ig1.getNumberOfInputChannels() + ig2.getNumberOfInputChannels(), union.getNumberOfInputChannels()); + + final MockInputChannel[][] inputChannels = new MockInputChannel[][]{ + MockInputChannel.createInputChannels(ig1, 3), + MockInputChannel.createInputChannels(ig2, 5) + }; + + inputChannels[0][0].readBuffer(); // 0 => 0 + inputChannels[1][2].readBuffer(); // 2 => 5 + inputChannels[1][0].readBuffer(); // 0 => 3 + inputChannels[1][1].readBuffer(); // 1 => 4 + inputChannels[0][1].readBuffer(); // 1 => 1 + inputChannels[1][3].readBuffer(); // 3 => 6 + inputChannels[0][2].readBuffer(); // 1 => 2 + inputChannels[1][4].readBuffer(); // 4 => 7 + + assertEquals(0, union.getNextBufferOrEvent().getChannelIndex()); + assertEquals(5, union.getNextBufferOrEvent().getChannelIndex()); + assertEquals(3, union.getNextBufferOrEvent().getChannelIndex()); + assertEquals(4, union.getNextBufferOrEvent().getChannelIndex()); + assertEquals(1, union.getNextBufferOrEvent().getChannelIndex()); + assertEquals(6, union.getNextBufferOrEvent().getChannelIndex()); + assertEquals(2, union.getNextBufferOrEvent().getChannelIndex()); + assertEquals(7, union.getNextBufferOrEvent().getChannelIndex()); + } +} diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/queue/PipelinedPartitionQueueTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/queue/PipelinedPartitionQueueTest.java index e85a7f722fc73..118b00bf7b34b 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/queue/PipelinedPartitionQueueTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/queue/PipelinedPartitionQueueTest.java @@ -23,9 +23,9 @@ import org.apache.flink.runtime.io.network.buffer.BufferPool; import org.apache.flink.runtime.io.network.buffer.BufferProvider; import org.apache.flink.runtime.io.network.buffer.NetworkBufferPool; -import org.apache.flink.runtime.io.network.partition.MockConsumer; -import org.apache.flink.runtime.io.network.partition.MockNotificationListener; -import org.apache.flink.runtime.io.network.partition.MockProducer; +import org.apache.flink.runtime.io.network.util.MockConsumer; +import org.apache.flink.runtime.io.network.util.MockNotificationListener; +import org.apache.flink.runtime.io.network.util.MockProducer; import org.apache.flink.runtime.io.network.partition.queue.IntermediateResultPartitionQueueIterator.AlreadySubscribedException; import org.apache.flink.runtime.util.event.NotificationListener; import org.junit.Before; diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/MockConsumer.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/util/MockConsumer.java similarity index 98% rename from flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/MockConsumer.java rename to flink-runtime/src/test/java/org/apache/flink/runtime/io/network/util/MockConsumer.java index 79494aa033abf..62375a62d56a1 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/MockConsumer.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/util/MockConsumer.java @@ -16,7 +16,7 @@ * limitations under the License. */ -package org.apache.flink.runtime.io.network.partition; +package org.apache.flink.runtime.io.network.util; import org.apache.flink.core.memory.MemorySegment; import org.apache.flink.runtime.io.network.buffer.Buffer; diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/reader/MockBufferReader.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/util/MockInputChannel.java similarity index 57% rename from flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/reader/MockBufferReader.java rename to flink-runtime/src/test/java/org/apache/flink/runtime/io/network/util/MockInputChannel.java index d439eeb6da355..301169ab123a7 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/reader/MockBufferReader.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/util/MockInputChannel.java @@ -16,111 +16,115 @@ * limitations under the License. */ -package org.apache.flink.runtime.io.network.api.reader; +package org.apache.flink.runtime.io.network.util; -import org.apache.flink.core.memory.DataInputView; -import org.apache.flink.core.memory.DataOutputView; -import org.apache.flink.runtime.event.task.TaskEvent; -import org.apache.flink.runtime.execution.RuntimeEnvironment; -import org.apache.flink.runtime.io.network.MockNetworkEnvironment; import org.apache.flink.runtime.io.network.api.EndOfPartitionEvent; import org.apache.flink.runtime.io.network.api.EndOfSuperstepEvent; import org.apache.flink.runtime.io.network.api.serialization.EventSerializer; import org.apache.flink.runtime.io.network.buffer.Buffer; import org.apache.flink.runtime.io.network.partition.consumer.InputChannel; -import org.apache.flink.runtime.jobgraph.IntermediateDataSetID; +import org.apache.flink.runtime.io.network.partition.consumer.SingleInputGate; import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID; +import org.mockito.Mockito; import org.mockito.invocation.InvocationOnMock; import org.mockito.stubbing.Answer; import org.mockito.stubbing.OngoingStubbing; import java.io.IOException; +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkNotNull; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; -public class MockBufferReader { +/** + * A mocked input channel. + */ +public class MockInputChannel { - protected final BufferReader reader; + private final InputChannel mock = Mockito.mock(InputChannel.class); - protected final InputChannel inputChannel = mock(InputChannel.class); + private final SingleInputGate inputGate; // Abusing Mockito here... ;) protected OngoingStubbing stubbing; - public MockBufferReader() throws IOException { - reader = new BufferReader(mock(RuntimeEnvironment.class), MockNetworkEnvironment.getMock(), new IntermediateDataSetID(), 1, 0); - reader.setInputChannel(new IntermediateResultPartitionID(), inputChannel); + public MockInputChannel(SingleInputGate inputGate, int channelIndex) { + checkArgument(channelIndex >= 0); + this.inputGate = checkNotNull(inputGate); + + when(mock.getChannelIndex()).thenReturn(channelIndex); } - MockBufferReader read(Buffer buffer) throws IOException { + public MockInputChannel read(Buffer buffer) throws IOException { if (stubbing == null) { - stubbing = when(inputChannel.getNextBuffer()).thenReturn(buffer); + stubbing = when(mock.getNextBuffer()).thenReturn(buffer); } else { stubbing = stubbing.thenReturn(buffer); } - reader.onAvailableInputChannel(inputChannel); + inputGate.onAvailableBuffer(mock); return this; } - MockBufferReader readBuffer() throws IOException { + public MockInputChannel readBuffer() throws IOException { final Buffer buffer = mock(Buffer.class); when(buffer.isBuffer()).thenReturn(true); return read(buffer); } - MockBufferReader readEvent() throws IOException { + public MockInputChannel readEvent() throws IOException { return read(EventSerializer.toBuffer(new TestTaskEvent())); } - MockBufferReader finishSuperstep() throws IOException { + public MockInputChannel readEndOfSuperstepEvent() throws IOException { return read(EventSerializer.toBuffer(EndOfSuperstepEvent.INSTANCE)); } - MockBufferReader finish() throws IOException { + public MockInputChannel readEndOfPartitionEvent() throws IOException { final Answer answer = new Answer() { @Override public Buffer answer(InvocationOnMock invocationOnMock) throws Throwable { // Return true after finishing - when(inputChannel.isReleased()).thenReturn(true); + when(mock.isReleased()).thenReturn(true); return EventSerializer.toBuffer(EndOfPartitionEvent.INSTANCE); } }; if (stubbing == null) { - stubbing = when(inputChannel.getNextBuffer()).thenAnswer(answer); + stubbing = when(mock.getNextBuffer()).thenAnswer(answer); } else { stubbing = stubbing.thenAnswer(answer); } - reader.onAvailableInputChannel(inputChannel); + inputGate.onAvailableBuffer(mock); return this; } - public BufferReader getMock() { - return reader; + public InputChannel getInputChannel() { + return mock; } // ------------------------------------------------------------------------ - public static class TestTaskEvent extends TaskEvent { + public static MockInputChannel[] createInputChannels(SingleInputGate inputGate, int numberOfInputChannels) { + checkNotNull(inputGate); + checkArgument(numberOfInputChannels > 0); - public TestTaskEvent() { - } + MockInputChannel[] mocks = new MockInputChannel[numberOfInputChannels]; - @Override - public void write(DataOutputView out) throws IOException { - } + for (int i = 0; i < numberOfInputChannels; i++) { + mocks[i] = new MockInputChannel(inputGate, i); - @Override - public void read(DataInputView in) throws IOException { + inputGate.setInputChannel(new IntermediateResultPartitionID(), mocks[i].getInputChannel()); } + + return mocks; } } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/MockNotificationListener.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/util/MockNotificationListener.java similarity index 96% rename from flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/MockNotificationListener.java rename to flink-runtime/src/test/java/org/apache/flink/runtime/io/network/util/MockNotificationListener.java index 928ac51c7529f..56e002538793b 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/MockNotificationListener.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/util/MockNotificationListener.java @@ -16,7 +16,7 @@ * limitations under the License. */ -package org.apache.flink.runtime.io.network.partition; +package org.apache.flink.runtime.io.network.util; import org.apache.flink.runtime.util.event.NotificationListener; diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/MockProducer.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/util/MockProducer.java similarity index 98% rename from flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/MockProducer.java rename to flink-runtime/src/test/java/org/apache/flink/runtime/io/network/util/MockProducer.java index 1cfe75d65eafc..44d8ffe2dc9b4 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/MockProducer.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/util/MockProducer.java @@ -16,7 +16,7 @@ * limitations under the License. */ -package org.apache.flink.runtime.io.network.partition; +package org.apache.flink.runtime.io.network.util; import org.apache.flink.core.memory.MemorySegment; import org.apache.flink.runtime.io.network.buffer.Buffer; diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/util/MockSingleInputGate.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/util/MockSingleInputGate.java new file mode 100644 index 0000000000000..3c708acb10f14 --- /dev/null +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/util/MockSingleInputGate.java @@ -0,0 +1,137 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.io.network.util; + +import org.apache.flink.runtime.io.network.buffer.Buffer; +import org.apache.flink.runtime.io.network.partition.consumer.SingleInputGate; +import org.apache.flink.runtime.jobgraph.IntermediateDataSetID; +import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkElementIndex; +import static org.mockito.Mockito.spy; + +public class MockSingleInputGate { + + protected final SingleInputGate inputGate; + + protected final MockInputChannel[] inputChannels; + + public MockSingleInputGate(int numberOfInputChannels) { + this(numberOfInputChannels, true); + } + + public MockSingleInputGate(int numberOfInputChannels, boolean initialize) { + checkArgument(numberOfInputChannels >= 1); + + this.inputGate = spy(new SingleInputGate(new IntermediateDataSetID(), 0, numberOfInputChannels)); + + this.inputChannels = new MockInputChannel[numberOfInputChannels]; + + if (initialize) { + for (int i = 0; i < numberOfInputChannels; i++) { + inputChannels[i] = new MockInputChannel(inputGate, i); + inputGate.setInputChannel(new IntermediateResultPartitionID(), inputChannels[i].getInputChannel()); + } + } + } + + public MockSingleInputGate read(Buffer buffer, int channelIndex) throws IOException { + checkElementIndex(channelIndex, inputGate.getNumberOfInputChannels()); + + inputChannels[channelIndex].read(buffer); + + return this; + } + + public MockSingleInputGate readBuffer() throws IOException { + return readBuffer(0); + } + + public MockSingleInputGate readBuffer(int channelIndex) throws IOException { + inputChannels[channelIndex].readBuffer(); + + return this; + } + + public MockSingleInputGate readEvent() throws IOException { + return readEvent(0); + } + + public MockSingleInputGate readEvent(int channelIndex) throws IOException { + inputChannels[channelIndex].readEvent(); + + return this; + } + + public MockSingleInputGate readEndOfSuperstepEvent() throws IOException { + for (MockInputChannel inputChannel : inputChannels) { + inputChannel.readEndOfSuperstepEvent(); + } + + return this; + } + + public MockSingleInputGate readEndOfSuperstepEvent(int channelIndex) throws IOException { + inputChannels[channelIndex].readEndOfSuperstepEvent(); + + return this; + } + + public MockSingleInputGate readEndOfPartitionEvent() throws IOException { + for (MockInputChannel inputChannel : inputChannels) { + inputChannel.readEndOfPartitionEvent(); + } + + return this; + } + + public MockSingleInputGate readEndOfPartitionEvent(int channelIndex) throws IOException { + inputChannels[channelIndex].readEndOfPartitionEvent(); + + return this; + } + + public SingleInputGate getInputGate() { + return inputGate; + } + + // ------------------------------------------------------------------------ + + public List readAllChannels() throws IOException { + final List readOrder = new ArrayList(inputChannels.length); + + for (int i = 0; i < inputChannels.length; i++) { + readOrder.add(i); + } + + Collections.shuffle(readOrder); + + for (int channelIndex : readOrder) { + inputChannels[channelIndex].readBuffer(); + } + + return readOrder; + } +} diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/util/TestTaskEvent.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/util/TestTaskEvent.java new file mode 100644 index 0000000000000..4f547aa8d092c --- /dev/null +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/util/TestTaskEvent.java @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.io.network.util; + +import org.apache.flink.core.memory.DataInputView; +import org.apache.flink.core.memory.DataOutputView; +import org.apache.flink.runtime.event.task.TaskEvent; + +import java.io.IOException; + +public class TestTaskEvent extends TaskEvent { + + private double val0; + + private long val1; + + public TestTaskEvent() { + this(0, 0); + } + + public TestTaskEvent(double val0, long val1) { + this.val0 = val0; + this.val1 = val1; + } + + @Override + public void write(DataOutputView out) throws IOException { + out.writeDouble(val0); + out.writeLong(val1); + } + + @Override + public void read(DataInputView in) throws IOException { + val0 = in.readDouble(); + val1 = in.readLong(); + } + + @Override + public boolean equals(Object obj) { + if (obj instanceof TestTaskEvent) { + TestTaskEvent other = (TestTaskEvent) obj; + + return val0 == other.val0 && val1 == other.val1; + } + + return false; + } +} diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/DataSinkTaskTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/DataSinkTaskTest.java index 9ce0d50809477..1a4b7f131b2c3 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/DataSinkTaskTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/DataSinkTaskTest.java @@ -21,7 +21,7 @@ import org.apache.flink.api.common.typeutils.record.RecordComparatorFactory; import org.apache.flink.api.java.record.io.DelimitedOutputFormat; import org.apache.flink.configuration.Configuration; -import org.apache.flink.runtime.io.network.api.reader.MockIteratorBufferReader; +import org.apache.flink.runtime.io.network.api.reader.IteratorWrappingMockSingleInputGate; import org.apache.flink.runtime.io.network.api.writer.BufferWriter; import org.apache.flink.runtime.operators.testutils.InfiniteInputIterator; import org.apache.flink.runtime.operators.testutils.TaskCancelThread; @@ -144,7 +144,7 @@ public void testUnionDataSinkTask() { super.initEnvironment(MEMORY_MANAGER_SIZE, NETWORK_BUFFER_SIZE); - MockIteratorBufferReader[] readers = new MockIteratorBufferReader[4]; + IteratorWrappingMockSingleInputGate[] readers = new IteratorWrappingMockSingleInputGate[4]; readers[0] = super.addInput(new UniformRecordGenerator(keyCnt, valCnt, 0, 0, false), 0, false); readers[1] = super.addInput(new UniformRecordGenerator(keyCnt, valCnt, keyCnt, 0, false), 0, false); readers[2] = super.addInput(new UniformRecordGenerator(keyCnt, valCnt, keyCnt * 2, 0, false), 0, false); @@ -157,7 +157,7 @@ public void testUnionDataSinkTask() { try { // For the union reader to work, we need to start notifications *after* the union reader // has been initialized. - for (MockIteratorBufferReader reader : readers) { + for (IteratorWrappingMockSingleInputGate reader : readers) { reader.read(); } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/MockEnvironment.java b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/MockEnvironment.java index 7aab0500782f7..7fb13e377e3c8 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/MockEnvironment.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/MockEnvironment.java @@ -27,14 +27,14 @@ import org.apache.flink.runtime.execution.Environment; import org.apache.flink.runtime.io.disk.iomanager.IOManager; import org.apache.flink.runtime.io.disk.iomanager.IOManagerAsync; -import org.apache.flink.runtime.io.network.api.reader.BufferReader; -import org.apache.flink.runtime.io.network.api.reader.MockIteratorBufferReader; +import org.apache.flink.runtime.io.network.api.reader.IteratorWrappingMockSingleInputGate; import org.apache.flink.runtime.io.network.api.serialization.AdaptiveSpanningRecordDeserializer; import org.apache.flink.runtime.io.network.api.serialization.RecordDeserializer; import org.apache.flink.runtime.io.network.api.writer.BufferWriter; import org.apache.flink.runtime.io.network.buffer.Buffer; import org.apache.flink.runtime.io.network.buffer.BufferProvider; import org.apache.flink.runtime.io.network.buffer.BufferRecycler; +import org.apache.flink.runtime.io.network.partition.consumer.InputGate; import org.apache.flink.runtime.jobgraph.JobID; import org.apache.flink.runtime.jobgraph.JobVertexID; import org.apache.flink.runtime.jobgraph.tasks.InputSplitProvider; @@ -69,7 +69,7 @@ public class MockEnvironment implements Environment { private final Configuration taskConfiguration; - private final List inputs; + private final List inputs; private final List outputs; @@ -82,7 +82,7 @@ public class MockEnvironment implements Environment { public MockEnvironment(long memorySize, MockInputSplitProvider inputSplitProvider, int bufferSize) { this.jobConfiguration = new Configuration(); this.taskConfiguration = new Configuration(); - this.inputs = new LinkedList(); + this.inputs = new LinkedList(); this.outputs = new LinkedList(); this.memManager = new DefaultMemoryManager(memorySize, 1); @@ -91,11 +91,11 @@ public MockEnvironment(long memorySize, MockInputSplitProvider inputSplitProvide this.bufferSize = bufferSize; } - public MockIteratorBufferReader addInput(MutableObjectIterator inputIterator) { + public IteratorWrappingMockSingleInputGate addInput(MutableObjectIterator inputIterator) { try { - final MockIteratorBufferReader reader = new MockIteratorBufferReader(bufferSize, Record.class, inputIterator); + final IteratorWrappingMockSingleInputGate reader = new IteratorWrappingMockSingleInputGate(bufferSize, Record.class, inputIterator); - inputs.add(reader.getMock()); + inputs.add(reader.getInputGate()); return reader; } @@ -235,13 +235,15 @@ public BufferWriter[] getAllWriters() { } @Override - public BufferReader getReader(int index) { + public InputGate getInputGate(int index) { return inputs.get(index); } @Override - public BufferReader[] getAllReaders() { - return inputs.toArray(new BufferReader[inputs.size()]); + public InputGate[] getAllInputGates() { + InputGate[] gates = new InputGate[inputs.size()]; + inputs.toArray(gates); + return gates; } @Override diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/TaskTestBase.java b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/TaskTestBase.java index 23cb23beb4d66..e0776d9e68fda 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/TaskTestBase.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/TaskTestBase.java @@ -27,7 +27,7 @@ import org.apache.flink.configuration.Configuration; import org.apache.flink.core.fs.FileSystem.WriteMode; import org.apache.flink.core.fs.Path; -import org.apache.flink.runtime.io.network.api.reader.MockIteratorBufferReader; +import org.apache.flink.runtime.io.network.api.reader.IteratorWrappingMockSingleInputGate; import org.apache.flink.runtime.jobgraph.tasks.AbstractInvokable; import org.apache.flink.runtime.memorymanager.MemoryManager; import org.apache.flink.runtime.operators.PactDriver; @@ -55,14 +55,14 @@ public void initEnvironment(long memorySize, int bufferSize) { this.mockEnv = new MockEnvironment(this.memorySize, this.inputSplitProvider, bufferSize); } - public MockIteratorBufferReader addInput(MutableObjectIterator input, int groupId) { - final MockIteratorBufferReader reader = addInput(input, groupId, true); + public IteratorWrappingMockSingleInputGate addInput(MutableObjectIterator input, int groupId) { + final IteratorWrappingMockSingleInputGate reader = addInput(input, groupId, true); return reader; } - public MockIteratorBufferReader addInput(MutableObjectIterator input, int groupId, boolean read) { - final MockIteratorBufferReader reader = this.mockEnv.addInput(input); + public IteratorWrappingMockSingleInputGate addInput(MutableObjectIterator input, int groupId, boolean read) { + final IteratorWrappingMockSingleInputGate reader = this.mockEnv.addInput(input); TaskConfig conf = new TaskConfig(this.mockEnv.getTaskConfiguration()); conf.addInputToGroup(groupId); conf.setInputSerializer(RecordSerializerFactory.get(), groupId); diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/util/event/EventNotificationHandlerTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/util/event/TaskEventHandlerTest.java similarity index 91% rename from flink-runtime/src/test/java/org/apache/flink/runtime/util/event/EventNotificationHandlerTest.java rename to flink-runtime/src/test/java/org/apache/flink/runtime/util/event/TaskEventHandlerTest.java index 625b93ff16cd5..5c6aeb719ae8d 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/util/event/EventNotificationHandlerTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/util/event/TaskEventHandlerTest.java @@ -25,13 +25,14 @@ import org.apache.flink.runtime.event.task.TaskEvent; import org.apache.flink.runtime.event.task.IntegerTaskEvent; import org.apache.flink.runtime.event.task.StringTaskEvent; +import org.apache.flink.runtime.io.network.api.TaskEventHandler; import org.junit.Test; /** - * This class contains unit tests for the {@link EventNotificationHandler}. + * This class contains unit tests for the {@link TaskEventHandler}. * */ -public class EventNotificationHandlerTest { +public class TaskEventHandlerTest { /** * A test implementation of an {@link EventListener}. * @@ -66,12 +67,12 @@ public TaskEvent getLastReceivedEvent() { } /** - * Tests the publish/subscribe mechanisms implemented in the {@link EventNotificationHandler}. + * Tests the publish/subscribe mechanisms implemented in the {@link TaskEventHandler}. */ @Test public void testEventNotificationManager() { - final EventNotificationHandler evm = new EventNotificationHandler(); + final TaskEventHandler evm = new TaskEventHandler(); final TestEventListener listener = new TestEventListener(); evm.subscribe(listener, StringTaskEvent.class); diff --git a/flink-runtime/src/test/scala/org/apache/flink/runtime/jobmanager/Tasks.scala b/flink-runtime/src/test/scala/org/apache/flink/runtime/jobmanager/Tasks.scala index 57a61da6e31db..9bb1a3bb1a5d5 100644 --- a/flink-runtime/src/test/scala/org/apache/flink/runtime/jobmanager/Tasks.scala +++ b/flink-runtime/src/test/scala/org/apache/flink/runtime/jobmanager/Tasks.scala @@ -72,7 +72,8 @@ object Tasks { var reader: RecordReader[IntegerRecord] = _ var writer: RecordWriter[IntegerRecord] = _ override def registerInputOutput(): Unit = { - reader = new RecordReader[IntegerRecord](getEnvironment.getReader(0), classOf[IntegerRecord]) + reader = new RecordReader[IntegerRecord](getEnvironment.getInputGate(0), + classOf[IntegerRecord]) writer = new RecordWriter[IntegerRecord](getEnvironment.getWriter(0)) } @@ -101,7 +102,7 @@ object Tasks { override def registerInputOutput(): Unit = { val env = getEnvironment - reader = new RecordReader[IntegerRecord](env.getReader(0), classOf[IntegerRecord]) + reader = new RecordReader[IntegerRecord](env.getInputGate(0), classOf[IntegerRecord]) } override def invoke(): Unit = { @@ -158,7 +159,7 @@ object Tasks { override def registerInputOutput(): Unit = { val env = getEnvironment - reader = new RecordReader[IntegerRecord](env.getReader(0), classOf[IntegerRecord]) + reader = new RecordReader[IntegerRecord](env.getInputGate(0), classOf[IntegerRecord]) } override def invoke(): Unit = { @@ -173,8 +174,8 @@ object Tasks { override def registerInputOutput(): Unit = { val env = getEnvironment - reader1 = new RecordReader[IntegerRecord](env.getReader(0), classOf[IntegerRecord]) - reader2 = new RecordReader[IntegerRecord](env.getReader(1), classOf[IntegerRecord]) + reader1 = new RecordReader[IntegerRecord](env.getInputGate(0), classOf[IntegerRecord]) + reader2 = new RecordReader[IntegerRecord](env.getInputGate(1), classOf[IntegerRecord]) } override def invoke(): Unit = { @@ -191,9 +192,9 @@ object Tasks { override def registerInputOutput(): Unit = { val env = getEnvironment - reader1 = new RecordReader[IntegerRecord](env.getReader(0), classOf[IntegerRecord]) - reader2 = new RecordReader[IntegerRecord](env.getReader(1), classOf[IntegerRecord]) - reader3 = new RecordReader[IntegerRecord](env.getReader(2), classOf[IntegerRecord]) + reader1 = new RecordReader[IntegerRecord](env.getInputGate(0), classOf[IntegerRecord]) + reader2 = new RecordReader[IntegerRecord](env.getInputGate(1), classOf[IntegerRecord]) + reader3 = new RecordReader[IntegerRecord](env.getInputGate(2), classOf[IntegerRecord]) } override def invoke(): Unit = { @@ -239,7 +240,7 @@ object Tasks { class ExceptionReceiver extends AbstractInvokable { override def registerInputOutput(): Unit = { - new RecordReader[IntegerRecord](getEnvironment.getReader(0), classOf[IntegerRecord]) + new RecordReader[IntegerRecord](getEnvironment.getInputGate(0), classOf[IntegerRecord]) } override def invoke(): Unit = { @@ -280,7 +281,7 @@ object Tasks { class BlockingReceiver extends AbstractInvokable { override def registerInputOutput(): Unit = { - new RecordReader[IntegerRecord](getEnvironment.getReader(0), classOf[IntegerRecord]) + new RecordReader[IntegerRecord](getEnvironment.getInputGate(0), classOf[IntegerRecord]) } override def invoke(): Unit = { diff --git a/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/streamvertex/CoStreamVertex.java b/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/streamvertex/CoStreamVertex.java index de4660ad4af59..df7bcad093b75 100644 --- a/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/streamvertex/CoStreamVertex.java +++ b/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/streamvertex/CoStreamVertex.java @@ -17,11 +17,8 @@ package org.apache.flink.streaming.api.streamvertex; -import java.util.ArrayList; - -import org.apache.flink.runtime.io.network.api.reader.BufferReader; -import org.apache.flink.runtime.io.network.api.reader.BufferReaderBase; -import org.apache.flink.runtime.io.network.api.reader.UnionBufferReader; +import org.apache.flink.runtime.io.network.partition.consumer.InputGate; +import org.apache.flink.runtime.io.network.partition.consumer.UnionInputGate; import org.apache.flink.runtime.plugable.DeserializationDelegate; import org.apache.flink.streaming.api.invokable.operator.co.CoInvokable; import org.apache.flink.streaming.api.streamrecord.StreamRecord; @@ -31,6 +28,8 @@ import org.apache.flink.streaming.io.IndexedReaderIterator; import org.apache.flink.util.MutableObjectIterator; +import java.util.ArrayList; + public class CoStreamVertex extends StreamVertex { protected StreamRecordSerializer inputDeserializer1 = null; @@ -77,12 +76,12 @@ protected void setConfigInputs() throws StreamVertexException { int numberOfInputs = configuration.getNumberOfInputs(); - ArrayList inputList1 = new ArrayList(); - ArrayList inputList2 = new ArrayList(); + ArrayList inputList1 = new ArrayList(); + ArrayList inputList2 = new ArrayList(); for (int i = 0; i < numberOfInputs; i++) { int inputType = configuration.getInputIndex(i); - BufferReader reader = getEnvironment().getReader(i); + InputGate reader = getEnvironment().getInputGate(i); switch (inputType) { case 1: inputList1.add(reader); @@ -95,11 +94,11 @@ protected void setConfigInputs() throws StreamVertexException { } } - final BufferReaderBase reader1 = inputList1.size() == 1 ? inputList1.get(0) - : new UnionBufferReader(inputList1.toArray(new BufferReader[inputList1.size()])); + final InputGate reader1 = inputList1.size() == 1 ? inputList1.get(0) + : new UnionInputGate(inputList1.toArray(new InputGate[inputList1.size()])); - final BufferReaderBase reader2 = inputList2.size() == 1 ? inputList2.get(0) - : new UnionBufferReader(inputList2.toArray(new BufferReader[inputList2.size()])); + final InputGate reader2 = inputList2.size() == 1 ? inputList2.get(0) + : new UnionInputGate(inputList2.toArray(new InputGate[inputList2.size()])); coReader = new CoRecordReader>, DeserializationDelegate>>( reader1, reader2); diff --git a/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/streamvertex/InputHandler.java b/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/streamvertex/InputHandler.java index de3fd2b0485da..73dbfce4920d1 100644 --- a/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/streamvertex/InputHandler.java +++ b/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/streamvertex/InputHandler.java @@ -19,7 +19,7 @@ import org.apache.flink.api.common.typeutils.TypeSerializer; import org.apache.flink.runtime.io.network.api.reader.MutableReader; -import org.apache.flink.runtime.io.network.api.reader.UnionBufferReader; +import org.apache.flink.runtime.io.network.partition.consumer.UnionInputGate; import org.apache.flink.runtime.plugable.DeserializationDelegate; import org.apache.flink.streaming.api.StreamConfig; import org.apache.flink.streaming.api.streamrecord.StreamRecord; @@ -54,14 +54,12 @@ protected void setConfigInputs() throws StreamVertexException { if (numberOfInputs > 0) { if (numberOfInputs < 2) { - inputs = new IndexedMutableReader>>( - streamVertex.getEnvironment().getReader(0)); + streamVertex.getEnvironment().getInputGate(0)); } else { - UnionBufferReader reader = new UnionBufferReader(streamVertex.getEnvironment() - .getAllReaders()); - inputs = new IndexedMutableReader>>(reader); + inputs = new IndexedMutableReader>>( + new UnionInputGate(streamVertex.getEnvironment().getAllInputGates())); } inputIter = createInputIterator(); diff --git a/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/io/CoRecordReader.java b/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/io/CoRecordReader.java index 0b1b37329ae42..bb3a6598879cb 100755 --- a/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/io/CoRecordReader.java +++ b/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/io/CoRecordReader.java @@ -18,13 +18,13 @@ package org.apache.flink.streaming.io; import org.apache.flink.core.io.IOReadableWritable; -import org.apache.flink.runtime.event.task.TaskEvent; -import org.apache.flink.runtime.io.network.api.reader.BufferReaderBase; +import org.apache.flink.runtime.io.network.api.reader.AbstractReader; import org.apache.flink.runtime.io.network.api.reader.MutableRecordReader; -import org.apache.flink.runtime.io.network.api.reader.ReaderBase; import org.apache.flink.runtime.io.network.api.serialization.AdaptiveSpanningRecordDeserializer; import org.apache.flink.runtime.io.network.api.serialization.RecordDeserializer; -import org.apache.flink.runtime.io.network.buffer.Buffer; +import org.apache.flink.runtime.io.network.partition.consumer.BufferOrEvent; +import org.apache.flink.runtime.io.network.partition.consumer.InputGate; +import org.apache.flink.runtime.io.network.partition.consumer.UnionInputGate; import org.apache.flink.runtime.util.event.EventListener; import java.io.IOException; @@ -36,11 +36,11 @@ * types to read records effectively. */ @SuppressWarnings("rawtypes") -public class CoRecordReader implements ReaderBase, EventListener { +public class CoRecordReader extends AbstractReader implements EventListener { - private final BufferReaderBase bufferReader1; + private final InputGate bufferReader1; - private final BufferReaderBase bufferReader2; + private final InputGate bufferReader2; private final BlockingQueue availableRecordReaders = new LinkedBlockingQueue(); @@ -57,7 +57,9 @@ public class CoRecordReader(); } - bufferReader1.subscribeToReader(this); - bufferReader2.subscribeToReader(this); + bufferReader1.registerListener(this); + bufferReader2.registerListener(this); } public void requestPartitionsOnce() throws IOException { if (!hasRequestedPartitions) { - bufferReader1.requestPartitionsOnce(); - bufferReader2.requestPartitionsOnce(); + bufferReader1.requestPartitions(); + bufferReader2.requestPartitions(); hasRequestedPartitions = true; } @@ -115,18 +117,20 @@ protected int getNextRecord(T1 target1, T2 target2) throws IOException, Interrup return 1; } } + else { - final Buffer nextBuffer = bufferReader1.getNextBufferBlocking(); - final int channelIndex = bufferReader1.getChannelIndexOfLastBuffer(); + final BufferOrEvent boe = bufferReader1.getNextBufferOrEvent(); - if (nextBuffer == null) { - currentReaderIndex = 0; + if (boe.isBuffer()) { + reader1currentRecordDeserializer = reader1RecordDeserializers[boe.getChannelIndex()]; + reader1currentRecordDeserializer.setNextBuffer(boe.getBuffer()); + } + else if (handleEvent(boe.getEvent())) { + currentReaderIndex = 0; - break; + break; + } } - - reader1currentRecordDeserializer = reader1RecordDeserializers[channelIndex]; - reader1currentRecordDeserializer.setNextBuffer(nextBuffer); } } else if (currentReaderIndex == 2) { @@ -145,18 +149,19 @@ else if (currentReaderIndex == 2) { return 2; } } + else { + final BufferOrEvent boe = bufferReader2.getNextBufferOrEvent(); - final Buffer nextBuffer = bufferReader2.getNextBufferBlocking(); - final int channelIndex = bufferReader2.getChannelIndexOfLastBuffer(); - - if (nextBuffer == null) { - currentReaderIndex = 0; + if (boe.isBuffer()) { + reader2currentRecordDeserializer = reader2RecordDeserializers[boe.getChannelIndex()]; + reader2currentRecordDeserializer.setNextBuffer(boe.getBuffer()); + } + else if (handleEvent(boe.getEvent())) { + currentReaderIndex = 0; - break; + break; + } } - - reader2currentRecordDeserializer = reader2RecordDeserializers[channelIndex]; - reader2currentRecordDeserializer.setNextBuffer(nextBuffer); } } else { @@ -174,7 +179,7 @@ private int getNextReaderIndexBlocking() throws InterruptedException { // ------------------------------------------------------------------------ @Override - public void onEvent(BufferReaderBase bufferReader) { + public void onEvent(InputGate bufferReader) { if (bufferReader == bufferReader1) { availableRecordReaders.add(1); } @@ -182,40 +187,4 @@ else if (bufferReader == bufferReader2) { availableRecordReaders.add(2); } } - - // ------------------------------------------------------------------------ - - @Override - public boolean isFinished() { - return bufferReader1.isFinished() && bufferReader2.isFinished(); - } - - @Override - public void subscribeToTaskEvent(EventListener eventListener, Class eventType) { - bufferReader1.subscribeToTaskEvent(eventListener, eventType); - bufferReader2.subscribeToTaskEvent(eventListener, eventType); - } - - @Override - public void sendTaskEvent(TaskEvent event) throws IOException, InterruptedException { - bufferReader1.sendTaskEvent(event); - bufferReader2.sendTaskEvent(event); - } - - @Override - public void setIterativeReader() { - bufferReader1.setIterativeReader(); - bufferReader2.setIterativeReader(); - } - - @Override - public void startNextSuperstep() { - bufferReader1.startNextSuperstep(); - bufferReader2.startNextSuperstep(); - } - - @Override - public boolean hasReachedEndOfSuperstep() { - return bufferReader1.hasReachedEndOfSuperstep() && bufferReader2.hasReachedEndOfSuperstep(); - } } diff --git a/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/io/IndexedMutableReader.java b/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/io/IndexedMutableReader.java index 41781308d075b..175dba2627eb1 100644 --- a/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/io/IndexedMutableReader.java +++ b/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/io/IndexedMutableReader.java @@ -19,22 +19,18 @@ package org.apache.flink.streaming.io; import org.apache.flink.core.io.IOReadableWritable; -import org.apache.flink.runtime.io.network.api.reader.BufferReaderBase; import org.apache.flink.runtime.io.network.api.reader.MutableRecordReader; +import org.apache.flink.runtime.io.network.partition.consumer.InputGate; public class IndexedMutableReader extends MutableRecordReader { - BufferReaderBase reader; + InputGate reader; - public IndexedMutableReader(BufferReaderBase reader) { + public IndexedMutableReader(InputGate reader) { super(reader); this.reader = reader; } - public int getLastChannelIndex() { - return reader.getChannelIndexOfLastBuffer(); - } - public int getNumberOfInputChannels() { return reader.getNumberOfInputChannels(); } diff --git a/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/io/IndexedReaderIterator.java b/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/io/IndexedReaderIterator.java index 6c8187baf88e6..18cdd4e5a600f 100644 --- a/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/io/IndexedReaderIterator.java +++ b/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/io/IndexedReaderIterator.java @@ -24,19 +24,10 @@ public class IndexedReaderIterator extends ReaderIterator { - private IndexedMutableReader> reader; - - public IndexedReaderIterator(IndexedMutableReader> reader, + public IndexedReaderIterator( + IndexedMutableReader> reader, TypeSerializer serializer) { - super(reader, serializer); - this.reader = reader; - } - public int getLastChannelIndex() { - return reader.getLastChannelIndex(); - } - - public int getNumberOfInputChannels() { - return reader.getNumberOfInputChannels(); + super(reader, serializer); } -} \ No newline at end of file +} diff --git a/flink-staging/flink-streaming/flink-streaming-core/src/test/java/org/apache/flink/streaming/util/MockContext.java b/flink-staging/flink-streaming/flink-streaming-core/src/test/java/org/apache/flink/streaming/util/MockContext.java index 03038b3bdf10f..4b13165d6861b 100644 --- a/flink-staging/flink-streaming/flink-streaming-core/src/test/java/org/apache/flink/streaming/util/MockContext.java +++ b/flink-staging/flink-streaming/flink-streaming-core/src/test/java/org/apache/flink/streaming/util/MockContext.java @@ -86,11 +86,6 @@ public StreamRecord next() throws IOException { return null; } } - - @Override - public int getLastChannelIndex() { - return 0; - } } public List getOutputs() { diff --git a/flink-tests/src/test/java/org/apache/flink/test/runtime/NetworkStackThroughputITCase.java b/flink-tests/src/test/java/org/apache/flink/test/runtime/NetworkStackThroughputITCase.java index f52de2217f813..36c7cbad50fe6 100644 --- a/flink-tests/src/test/java/org/apache/flink/test/runtime/NetworkStackThroughputITCase.java +++ b/flink-tests/src/test/java/org/apache/flink/test/runtime/NetworkStackThroughputITCase.java @@ -200,7 +200,7 @@ public static class SpeedTestForwarder extends AbstractInvokable { @Override public void registerInputOutput() { - this.reader = new RecordReader(getEnvironment().getReader(0), SpeedTestRecord.class); + this.reader = new RecordReader(getEnvironment().getInputGate(0), SpeedTestRecord.class); this.writer = new RecordWriter(getEnvironment().getWriter(0)); } @@ -222,7 +222,7 @@ public static class SpeedTestConsumer extends AbstractInvokable { @Override public void registerInputOutput() { - this.reader = new RecordReader(getEnvironment().getReader(0), SpeedTestRecord.class); + this.reader = new RecordReader(getEnvironment().getInputGate(0), SpeedTestRecord.class); } @Override