diff --git a/flink-connectors/flink-connector-datagen/src/test/java/org/apache/flink/connector/datagen/source/DataGeneratorSourceITCase.java b/flink-connectors/flink-connector-datagen/src/test/java/org/apache/flink/connector/datagen/source/DataGeneratorSourceITCase.java
index 76f5eca8494b0..842982ac5846a 100644
--- a/flink-connectors/flink-connector-datagen/src/test/java/org/apache/flink/connector/datagen/source/DataGeneratorSourceITCase.java
+++ b/flink-connectors/flink-connector-datagen/src/test/java/org/apache/flink/connector/datagen/source/DataGeneratorSourceITCase.java
@@ -19,18 +19,19 @@
 package org.apache.flink.connector.datagen.source;
 
 import org.apache.flink.api.common.eventtime.WatermarkStrategy;
-import org.apache.flink.api.common.functions.RichMapFunction;
-import org.apache.flink.api.common.state.CheckpointListener;
+import org.apache.flink.api.common.functions.FlatMapFunction;
 import org.apache.flink.api.common.typeinfo.Types;
 import org.apache.flink.api.connector.source.SourceReaderContext;
 import org.apache.flink.api.connector.source.util.ratelimit.RateLimiterStrategy;
-import org.apache.flink.api.java.tuple.Tuple2;
-import org.apache.flink.configuration.Configuration;
+import org.apache.flink.runtime.state.FunctionInitializationContext;
+import org.apache.flink.runtime.state.FunctionSnapshotContext;
 import org.apache.flink.runtime.testutils.MiniClusterResourceConfiguration;
+import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;
 import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.streaming.api.datastream.DataStreamSource;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
 import org.apache.flink.test.junit5.MiniClusterExtension;
+import org.apache.flink.util.Collector;
 import org.apache.flink.util.TestLogger;
 
 import org.junit.jupiter.api.Disabled;
@@ -39,11 +40,9 @@
 import org.junit.jupiter.api.extension.RegisterExtension;
 
 import java.util.List;
-import java.util.Map;
 import java.util.stream.Collectors;
 import java.util.stream.LongStream;
 
-import static java.util.stream.Collectors.summingInt;
 import static org.apache.flink.core.testutils.FlinkAssertions.anyCauseMatches;
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy;
@@ -178,57 +177,48 @@ void testGatedRateLimiter() throws Exception {
 
         env.setParallelism(PARALLELISM);
 
-        int capacityPerSubtaskPerCycle = 2;
-        int capacityPerCycle = // avoid rounding errors when spreading records among subtasks
-                PARALLELISM * capacityPerSubtaskPerCycle;
+        int capacityPerSubtaskPerCheckpoint = 2;
+        int capacityPerCheckpoint = // avoid rounding errors when spreading records among subtasks
+                PARALLELISM * capacityPerSubtaskPerCheckpoint;
 
         final GeneratorFunction<Long, Long> generatorFunction = index -> 1L;
 
-        // Allow each subtask to produce at least 3 cycles, gated by checkpoints
-        int count = capacityPerCycle * 3;
+        // produce slightly more elements than the checkpoint rate limit would allow
+        int count = capacityPerCheckpoint + 1;
 
         final DataGeneratorSource<Long> generatorSource =
                 new DataGeneratorSource<>(
                         generatorFunction,
                         count,
-                        RateLimiterStrategy.perCheckpoint(capacityPerCycle),
+                        RateLimiterStrategy.perCheckpoint(capacityPerCheckpoint),
                         Types.LONG);
 
         final DataStreamSource<Long> streamSource =
                 env.fromSource(generatorSource, WatermarkStrategy.noWatermarks(), "Data Generator");
 
-        final DataStream<Tuple2<Integer, Long>> map =
-                streamSource.map(new SubtaskAndCheckpointMapper());
-        final List<Tuple2<Integer, Long>> results = map.executeAndCollect(1000);
-
-        final Map<Tuple2<Integer, Long>, Integer> collect =
-                results.stream()
-                        .collect(
-                                Collectors.groupingBy(
-                                        x -> (new Tuple2<>(x.f0, x.f1)), summingInt(x -> 1)));
-        for (Map.Entry<Tuple2<Integer, Long>, Integer> entry : collect.entrySet()) {
-            assertThat(entry.getValue()).isEqualTo(capacityPerSubtaskPerCycle);
-        }
+        final DataStream<Long> map = streamSource.flatMap(new FirstCheckpointFilter());
+        final List<Long> results = map.executeAndCollect(1000);
+
+        assertThat(results).hasSize(capacityPerCheckpoint);
     }
 
-    private static class SubtaskAndCheckpointMapper
-            extends RichMapFunction<Long, Tuple2<Integer, Long>> implements CheckpointListener {
+    private static class FirstCheckpointFilter
+            implements FlatMapFunction<Long, Long>, CheckpointedFunction {
 
-        private long checkpointId = 0;
-        private int subtaskIndex;
+        private volatile boolean firstCheckpoint = true;
 
         @Override
-        public void open(Configuration parameters) {
-            subtaskIndex = getRuntimeContext().getIndexOfThisSubtask();
+        public void flatMap(Long value, Collector<Long> out) throws Exception {
+            if (firstCheckpoint) {
+                out.collect(value);
+            }
         }
 
         @Override
-        public Tuple2<Integer, Long> map(Long value) {
-            return new Tuple2<>(subtaskIndex, checkpointId);
+        public void snapshotState(FunctionSnapshotContext context) throws Exception {
+            firstCheckpoint = false;
        }
 
         @Override
-        public void notifyCheckpointComplete(long checkpointId) {
-            this.checkpointId = checkpointId;
-        }
+        public void initializeState(FunctionInitializationContext context) throws Exception {}
     }
 
     private DataStream<Long> getGeneratorSourceStream(
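Note on what testGatedRateLimiter now checks: RateLimiterStrategy.perCheckpoint(n) spreads n permits across the parallel subtasks, and each subtask's limiter hands out completed futures until its share is used up, after which acquire() returns a future that only completes once the next checkpoint finishes. A minimal sketch of such a checkpoint-gated limiter is below; the class is illustrative, not the flink-core implementation, and only the acquire()/notifyCheckpointComplete shape mirrors the RateLimiter contract:

    import java.util.concurrent.CompletableFuture;
    import java.util.concurrent.CompletionStage;

    class GatedRateLimiterSketch {
        private final int capacityPerCycle; // permits available between two checkpoints
        private int permitsLeft;
        private CompletableFuture<Void> gate; // pending future handed out while exhausted

        GatedRateLimiterSketch(int capacityPerCycle) {
            this.capacityPerCycle = capacityPerCycle;
            this.permitsLeft = capacityPerCycle;
        }

        // Completed future while permits remain; otherwise a future gated on the next checkpoint.
        CompletionStage<Void> acquire() {
            if (permitsLeft-- > 0) {
                return CompletableFuture.completedFuture(null);
            }
            if (gate == null) {
                gate = new CompletableFuture<>();
            }
            return gate;
        }

        // Called when a checkpoint completes: refill the permits and open the gate.
        void notifyCheckpointComplete(long checkpointId) {
            permitsLeft = capacityPerCycle;
            if (gate != null) {
                gate.complete(null);
                gate = null;
            }
        }
    }

With count = capacityPerCheckpoint + 1, exactly one element is held back until the first checkpoint completes; FirstCheckpointFilter drops everything it sees after its first snapshotState, so the collected result has exactly capacityPerCheckpoint elements.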
diff --git a/flink-connectors/flink-connector-datagen/src/test/java/org/apache/flink/connector/datagen/source/DataGeneratorSourceTest.java b/flink-connectors/flink-connector-datagen/src/test/java/org/apache/flink/connector/datagen/source/DataGeneratorSourceTest.java
index 2caf93f6e08a8..ebda0c941ccfc 100644
--- a/flink-connectors/flink-connector-datagen/src/test/java/org/apache/flink/connector/datagen/source/DataGeneratorSourceTest.java
+++ b/flink-connectors/flink-connector-datagen/src/test/java/org/apache/flink/connector/datagen/source/DataGeneratorSourceTest.java
@@ -86,10 +86,11 @@ void testRestoreEnumerator() throws Exception {
     @Test
     @DisplayName("Uses the underlying NumberSequenceSource correctly for checkpointing.")
     void testReaderCheckpoints() throws Exception {
-        final long from = 177;
-        final long mid = 333;
-        final long to = 563;
-        final long elementsPerCycle = (to - from) / 3;
+        final int numCycles = 3;
+        final long from = 0;
+        final long mid = 156;
+        final long to = 383;
+        final long elementsPerCycle = (to - from + 1) / numCycles;
 
         final TestingReaderOutput<Long> out = new TestingReaderOutput<>();
 
@@ -99,23 +100,48 @@ void testReaderCheckpoints() throws Exception {
                         new NumberSequenceSource.NumberSequenceSplit("split-1", from, mid),
                         new NumberSequenceSource.NumberSequenceSplit("split-2", mid + 1, to)));
 
-        long remainingInCycle = elementsPerCycle;
-        while (reader.pollNext(out) != InputStatus.END_OF_INPUT) {
-            if (--remainingInCycle <= 0) {
-                remainingInCycle = elementsPerCycle;
-                // checkpoint
-                List<NumberSequenceSource.NumberSequenceSplit> splits = reader.snapshotState(1L);
-
-                // re-create and restore
-                reader = createReader();
-                if (splits.isEmpty()) {
-                    reader.notifyNoMoreSplits();
-                } else {
-                    reader.addSplits(splits);
-                }
+        for (int cycle = 0; cycle < numCycles; cycle++) {
+            // this call is not required but mimics what happens at runtime
+            assertThat(reader.pollNext(out))
+                    .as(
+                            "Each poll should return a NOTHING_AVAILABLE status to explicitly trigger the availability check through SourceReader.isAvailable")
+                    .isSameAs(InputStatus.NOTHING_AVAILABLE);
+            for (int elementInCycle = 0; elementInCycle < elementsPerCycle; elementInCycle++) {
+                assertThat(reader.isAvailable())
+                        .as(
+                                "Data should always be available because the test uses no rate-limiting strategy and splits are provided.")
+                        .isCompleted();
+                // this never returns END_OF_INPUT because IteratorSourceReaderBase#pollNext does
+                // not immediately return END_OF_INPUT when the input is exhausted
+                assertThat(reader.pollNext(out))
+                        .as(
+                                "Each poll should return a NOTHING_AVAILABLE status to explicitly trigger the availability check through SourceReader.isAvailable")
+                        .isSameAs(InputStatus.NOTHING_AVAILABLE);
+            }
+            // checkpoint
+            List<NumberSequenceSource.NumberSequenceSplit> splits = reader.snapshotState(1L);
+            // first cycle partially consumes the first split
+            // second cycle consumes the remaining first split and partially consumes the second
+            // third cycle consumes the remaining second split
+            assertThat(splits).hasSize(numCycles - cycle - 1);
+
+            // re-create and restore
+            reader = createReader();
+            if (splits.isEmpty()) {
+                reader.notifyNoMoreSplits();
+            } else {
+                reader.addSplits(splits);
             }
         }
+        // we need to go through isAvailable again because IteratorSourceReaderBase#pollNext does
+        // not immediately return END_OF_INPUT when the input is exhausted
+        assertThat(reader.isAvailable())
+                .as(
+                        "Data should always be available because the test uses no rate-limiting strategy and splits are provided.")
+                .isCompleted();
+        assertThat(reader.pollNext(out)).isSameAs(InputStatus.END_OF_INPUT);
+
         final List<Long> result = out.getEmittedRecords();
         final Iterable<Long> expected = LongStream.range(from, to + 1)::iterator;
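The reworked loop spells out the SourceReader contract the reader under test relies on: whenever pollNext returns NOTHING_AVAILABLE, the caller must go through isAvailable() before polling again, which is exactly where the rate limiter gets evaluated. For reference, a driver loop in the spirit of what the test mimics could look like the sketch below; the helper class and method names are made up for illustration, while the pollNext/isAvailable contract itself is Flink's:

    import org.apache.flink.api.connector.source.ReaderOutput;
    import org.apache.flink.api.connector.source.SourceReader;

    final class SourceReaderDriverSketch {
        private SourceReaderDriverSketch() {}

        // Polls the reader to completion, awaiting availability whenever it
        // reports NOTHING_AVAILABLE.
        static <E> void drain(SourceReader<E, ?> reader, ReaderOutput<E> output) throws Exception {
            while (true) {
                switch (reader.pollNext(output)) {
                    case END_OF_INPUT:
                        return; // the reader has emitted everything
                    case NOTHING_AVAILABLE:
                        reader.isAvailable().get(); // blocks, e.g. on a rate-limit permit
                        break;
                    case MORE_AVAILABLE:
                        break; // poll again immediately
                }
            }
        }
    }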
diff --git a/flink-core/src/main/java/org/apache/flink/api/connector/source/util/ratelimit/RateLimitedSourceReader.java b/flink-core/src/main/java/org/apache/flink/api/connector/source/util/ratelimit/RateLimitedSourceReader.java
index 403ba36200c2e..aff9b5c266eb8 100644
--- a/flink-core/src/main/java/org/apache/flink/api/connector/source/util/ratelimit/RateLimitedSourceReader.java
+++ b/flink-core/src/main/java/org/apache/flink/api/connector/source/util/ratelimit/RateLimitedSourceReader.java
@@ -60,6 +60,10 @@ public void start() {
 
     @Override
     public InputStatus pollNext(ReaderOutput<E> output) throws Exception {
+        if (availabilityFuture == null) {
+            // force isAvailable() to be called first to evaluate rate-limiting
+            return InputStatus.NOTHING_AVAILABLE;
+        }
         // reset future because the next record may hit the rate limit
         availabilityFuture = null;
         final InputStatus inputStatus = sourceReader.pollNext(output);
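The early NOTHING_AVAILABLE return makes the reader well-behaved even for callers that poll first, as the updated DataGeneratorSourceTest now does: no record can be emitted before isAvailable() has installed an availability future from the rate limiter. A sketch of the isAvailable() side this guard pairs with follows; the field and the nested interface are stand-ins, and the actual RateLimitedSourceReader may combine this with the wrapped reader's own availability differently:

    import java.util.concurrent.CompletableFuture;
    import java.util.concurrent.CompletionStage;

    class AvailabilityGateSketch {
        interface RateLimiter {
            CompletionStage<Void> acquire(); // completes when a permit is granted
        }

        private final RateLimiter rateLimiter;
        private CompletableFuture<Void> availabilityFuture; // null until isAvailable() is called

        AvailabilityGateSketch(RateLimiter rateLimiter) {
            this.rateLimiter = rateLimiter;
        }

        public CompletableFuture<Void> isAvailable() {
            if (availabilityFuture == null) {
                // ask the rate limiter for a permit; completed immediately while under the limit
                availabilityFuture = rateLimiter.acquire().toCompletableFuture();
            }
            return availabilityFuture;
        }

        // pollNext() would then check `availabilityFuture == null` exactly as in the hunk above
        // and return NOTHING_AVAILABLE, forcing the caller through isAvailable() first.
    }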