Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,18 +19,19 @@
package org.apache.flink.connector.datagen.source;

import org.apache.flink.api.common.eventtime.WatermarkStrategy;
import org.apache.flink.api.common.functions.RichMapFunction;
import org.apache.flink.api.common.state.CheckpointListener;
import org.apache.flink.api.common.functions.FlatMapFunction;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.connector.source.SourceReaderContext;
import org.apache.flink.api.connector.source.util.ratelimit.RateLimiterStrategy;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.runtime.state.FunctionInitializationContext;
import org.apache.flink.runtime.state.FunctionSnapshotContext;
import org.apache.flink.runtime.testutils.MiniClusterResourceConfiguration;
import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.datastream.DataStreamSource;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.test.junit5.MiniClusterExtension;
import org.apache.flink.util.Collector;
import org.apache.flink.util.TestLogger;

import org.junit.jupiter.api.Disabled;
Expand All @@ -39,11 +40,9 @@
import org.junit.jupiter.api.extension.RegisterExtension;

import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.LongStream;

import static java.util.stream.Collectors.summingInt;
import static org.apache.flink.core.testutils.FlinkAssertions.anyCauseMatches;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy;
Expand Down Expand Up @@ -178,57 +177,48 @@ void testGatedRateLimiter() throws Exception {

env.setParallelism(PARALLELISM);

int capacityPerSubtaskPerCycle = 2;
int capacityPerCycle = // avoid rounding errors when spreading records among subtasks
PARALLELISM * capacityPerSubtaskPerCycle;
int capacityPerSubtaskPerCheckpoint = 2;
int capacityPerCheckpoint = // avoid rounding errors when spreading records among subtasks
PARALLELISM * capacityPerSubtaskPerCheckpoint;

final GeneratorFunction<Long, Long> generatorFunction = index -> 1L;

// Allow each subtask to produce at least 3 cycles, gated by checkpoints
int count = capacityPerCycle * 3;
// produce slightly more elements than the checkpoint-rate-limit would allow
int count = capacityPerCheckpoint + 1;
final DataGeneratorSource<Long> generatorSource =
new DataGeneratorSource<>(
generatorFunction,
count,
RateLimiterStrategy.perCheckpoint(capacityPerCycle),
RateLimiterStrategy.perCheckpoint(capacityPerCheckpoint),
Types.LONG);

final DataStreamSource<Long> streamSource =
env.fromSource(generatorSource, WatermarkStrategy.noWatermarks(), "Data Generator");
final DataStream<Tuple2<Integer, Long>> map =
streamSource.map(new SubtaskAndCheckpointMapper());
final List<Tuple2<Integer, Long>> results = map.executeAndCollect(1000);

final Map<Tuple2<Integer, Long>, Integer> collect =
results.stream()
.collect(
Collectors.groupingBy(
x -> (new Tuple2<>(x.f0, x.f1)), summingInt(x -> 1)));
for (Map.Entry<Tuple2<Integer, Long>, Integer> entry : collect.entrySet()) {
assertThat(entry.getValue()).isEqualTo(capacityPerSubtaskPerCycle);
}
final DataStream<Long> map = streamSource.flatMap(new FirstCheckpointFilter());
final List<Long> results = map.executeAndCollect(1000);

assertThat(results).hasSize(capacityPerCheckpoint);
}

private static class SubtaskAndCheckpointMapper
extends RichMapFunction<Long, Tuple2<Integer, Long>> implements CheckpointListener {
private static class FirstCheckpointFilter
implements FlatMapFunction<Long, Long>, CheckpointedFunction {

private long checkpointId = 0;
private int subtaskIndex;
private volatile boolean firstCheckpoint = true;

@Override
public void open(Configuration parameters) {
subtaskIndex = getRuntimeContext().getIndexOfThisSubtask();
public void flatMap(Long value, Collector<Long> out) throws Exception {
if (firstCheckpoint) {
out.collect(value);
}
}

@Override
public Tuple2<Integer, Long> map(Long value) {
return new Tuple2<>(subtaskIndex, checkpointId);
public void snapshotState(FunctionSnapshotContext context) throws Exception {
firstCheckpoint = false;
}

@Override
public void notifyCheckpointComplete(long checkpointId) {
this.checkpointId = checkpointId;
}
public void initializeState(FunctionInitializationContext context) throws Exception {}
}

private DataStream<Long> getGeneratorSourceStream(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,10 +86,11 @@ void testRestoreEnumerator() throws Exception {
@Test
@DisplayName("Uses the underlying NumberSequenceSource correctly for checkpointing.")
void testReaderCheckpoints() throws Exception {
final long from = 177;
final long mid = 333;
final long to = 563;
final long elementsPerCycle = (to - from) / 3;
final int numCycles = 3;
final long from = 0;
final long mid = 156;
final long to = 383;
final long elementsPerCycle = (to - from + 1) / numCycles;

final TestingReaderOutput<Long> out = new TestingReaderOutput<>();

Expand All @@ -99,23 +100,48 @@ void testReaderCheckpoints() throws Exception {
new NumberSequenceSource.NumberSequenceSplit("split-1", from, mid),
new NumberSequenceSource.NumberSequenceSplit("split-2", mid + 1, to)));

long remainingInCycle = elementsPerCycle;
while (reader.pollNext(out) != InputStatus.END_OF_INPUT) {
if (--remainingInCycle <= 0) {
remainingInCycle = elementsPerCycle;
// checkpoint
List<NumberSequenceSource.NumberSequenceSplit> splits = reader.snapshotState(1L);

// re-create and restore
reader = createReader();
if (splits.isEmpty()) {
reader.notifyNoMoreSplits();
} else {
reader.addSplits(splits);
}
for (int cycle = 0; cycle < numCycles; cycle++) {
// this call is not required but mimics what happens at runtime
assertThat(reader.pollNext(out))
.as(
"Each poll should return a NOTHING_AVAILABLE status to explicitly trigger the availability check through in SourceReader.isAvailable")
.isSameAs(InputStatus.NOTHING_AVAILABLE);
for (int elementInCycle = 0; elementInCycle < elementsPerCycle; elementInCycle++) {
assertThat(reader.isAvailable())
.as(
"There should be always data available because the test utilizes no rate-limiting strategy and splits are provided.")
.isCompleted();
// this never returns END_OF_INPUT because IteratorSourceReaderBase#pollNext does
// not immediately return END_OF_INPUT when the input is exhausted
assertThat(reader.pollNext(out))
.as(
"Each poll should return a NOTHING_AVAILABLE status to explicitly trigger the availability check through in SourceReader.isAvailable")
.isSameAs(InputStatus.NOTHING_AVAILABLE);
}
// checkpoint
List<NumberSequenceSource.NumberSequenceSplit> splits = reader.snapshotState(1L);
// the first cycle partially consumes the first split
// the second cycle consumes the remainder of the first split and partially consumes the second
// the third cycle consumes the remainder of the second split
assertThat(splits).hasSize(numCycles - cycle - 1);

// re-create and restore
reader = createReader();
if (splits.isEmpty()) {
reader.notifyNoMoreSplits();
} else {
reader.addSplits(splits);
}
}

// we need to go through isAvailable once more because IteratorSourceReaderBase#pollNext
// does not immediately return END_OF_INPUT when the input is exhausted
assertThat(reader.isAvailable())
.as(
"There should be always data available because the test utilizes no rate-limiting strategy and splits are provided.")
.isCompleted();
assertThat(reader.pollNext(out)).isSameAs(InputStatus.END_OF_INPUT);

final List<Long> result = out.getEmittedRecords();
final Iterable<Long> expected = LongStream.range(from, to + 1)::iterator;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,10 @@ public void start() {

@Override
public InputStatus pollNext(ReaderOutput<E> output) throws Exception {
if (availabilityFuture == null) {
// force isAvailable() to be called first to evaluate rate-limiting
return InputStatus.NOTHING_AVAILABLE;
}
// reset future because the next record may hit the rate limit
availabilityFuture = null;
final InputStatus inputStatus = sourceReader.pollNext(output);
Expand Down