Skip to content

Commit

Permalink
Merge pull request #16657: [BEAM-13563] Restructure Kinesis Source for AWS 2
Browse files Browse the repository at this point in the history
  • Loading branch information
aromanenko-dev committed Feb 2, 2022
2 parents b9c4919 + 1663d6d commit cc0b2c5
Show file tree
Hide file tree
Showing 6 changed files with 89 additions and 215 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -494,17 +494,12 @@ public Read withMaxCapacityPerShard(Integer maxCapacity) {

@Override
public PCollection<KinesisRecord> expand(PBegin input) {
checkArgument(getAWSClientsProvider() != null, "AWSClientsProvider is required");
checkArgument(getWatermarkPolicyFactory() != null, "WatermarkPolicyFactory is required");
checkArgument(getRateLimitPolicyFactory() != null, "RateLimitPolicyFactory is required");

Unbounded<KinesisRecord> unbounded =
org.apache.beam.sdk.io.Read.from(
new KinesisSource(
getAWSClientsProvider(),
getStreamName(),
getInitialPosition(),
getUpToDateThreshold(),
getWatermarkPolicyFactory(),
getRateLimitPolicyFactory(),
getRequestRecordsLimit(),
getMaxCapacityPerShard()));
org.apache.beam.sdk.io.Read.from(new KinesisSource(this));

PTransform<PBegin, PCollection<KinesisRecord>> transform = unbounded;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import java.io.IOException;
import java.util.NoSuchElementException;
import org.apache.beam.sdk.io.UnboundedSource;
import org.apache.beam.sdk.io.aws2.kinesis.KinesisIO.Read;
import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
import org.joda.time.Duration;
import org.joda.time.Instant;
Expand All @@ -39,62 +40,42 @@ class KinesisReader extends UnboundedSource.UnboundedReader<KinesisRecord> {

private static final Logger LOG = LoggerFactory.getLogger(KinesisReader.class);

private final Read spec;
private final SimplifiedKinesisClient kinesis;
private final KinesisSource source;
private final CheckpointGenerator initialCheckpointGenerator;
private final WatermarkPolicyFactory watermarkPolicyFactory;
private final RateLimitPolicyFactory rateLimitPolicyFactory;
private final Duration upToDateThreshold;

private final CheckpointGenerator checkpointGenerator;
private final Duration backlogBytesCheckThreshold;
private CustomOptional<KinesisRecord> currentRecord = CustomOptional.absent();
private long lastBacklogBytes;
private Instant backlogBytesLastCheckTime = new Instant(0L);
private ShardReadersPool shardReadersPool;
private final Integer maxCapacityPerShard;

KinesisReader(
Read spec,
SimplifiedKinesisClient kinesis,
CheckpointGenerator initialCheckpointGenerator,
KinesisSource source,
WatermarkPolicyFactory watermarkPolicyFactory,
RateLimitPolicyFactory rateLimitPolicyFactory,
Duration upToDateThreshold,
Integer maxCapacityPerShard) {
this(
kinesis,
initialCheckpointGenerator,
source,
watermarkPolicyFactory,
rateLimitPolicyFactory,
upToDateThreshold,
Duration.standardSeconds(30),
maxCapacityPerShard);
KinesisSource source) {
this(spec, kinesis, initialCheckpointGenerator, source, Duration.standardSeconds(30));
}

KinesisReader(
Read spec,
SimplifiedKinesisClient kinesis,
CheckpointGenerator initialCheckpointGenerator,
CheckpointGenerator checkpointGenerator,
KinesisSource source,
WatermarkPolicyFactory watermarkPolicyFactory,
RateLimitPolicyFactory rateLimitPolicyFactory,
Duration upToDateThreshold,
Duration backlogBytesCheckThreshold,
Integer maxCapacityPerShard) {
Duration backlogBytesCheckThreshold) {
this.spec = checkNotNull(spec, "spec");
this.kinesis = checkNotNull(kinesis, "kinesis");
this.initialCheckpointGenerator =
checkNotNull(initialCheckpointGenerator, "initialCheckpointGenerator");
this.watermarkPolicyFactory = watermarkPolicyFactory;
this.rateLimitPolicyFactory = rateLimitPolicyFactory;
this.checkpointGenerator = checkNotNull(checkpointGenerator, "checkpointGenerator");
this.source = source;
this.upToDateThreshold = upToDateThreshold;
this.backlogBytesCheckThreshold = backlogBytesCheckThreshold;
this.maxCapacityPerShard = maxCapacityPerShard;
}

/** Generates initial checkpoint and instantiates iterators for shards. */
@Override
public boolean start() throws IOException {
LOG.info("Starting reader using {}", initialCheckpointGenerator);
LOG.info("Starting reader using {}", checkpointGenerator);

try {
shardReadersPool = createShardReadersPool();
Expand Down Expand Up @@ -159,7 +140,8 @@ public UnboundedSource.CheckpointMark getCheckpointMark() {
* into account size of the records that were added to the stream after timestamp of the most
* recent record returned by the reader. If no records have yet been retrieved from the reader
* {@link UnboundedSource.UnboundedReader#BACKLOG_UNKNOWN} is returned. When currently processed
* record is not further behind than {@link #upToDateThreshold} then this method returns 0.
* record is not further behind than {@link Read#getUpToDateThreshold()} then this method returns
* 0.
*
* <p>The method can over-estimate size of the records for the split as it reports the backlog
* across all shards. This can lead to unnecessary decisions to scale up the number of workers but
Expand All @@ -172,51 +154,46 @@ public long getSplitBacklogBytes() {
Instant latestRecordTimestamp = shardReadersPool.getLatestRecordTimestamp();

if (latestRecordTimestamp.equals(BoundedWindow.TIMESTAMP_MIN_VALUE)) {
LOG.debug("Split backlog bytes for stream {} unknown", source.getStreamName());
LOG.debug("Split backlog bytes for stream {} unknown", spec.getStreamName());
return UnboundedSource.UnboundedReader.BACKLOG_UNKNOWN;
}

if (latestRecordTimestamp.plus(upToDateThreshold).isAfterNow()) {
if (latestRecordTimestamp.plus(spec.getUpToDateThreshold()).isAfterNow()) {
LOG.debug(
"Split backlog bytes for stream {} with latest record timestamp {}: 0 (latest record timestamp is up-to-date with threshold of {})",
source.getStreamName(),
spec.getStreamName(),
latestRecordTimestamp,
upToDateThreshold);
spec.getUpToDateThreshold());
return 0L;
}

if (backlogBytesLastCheckTime.plus(backlogBytesCheckThreshold).isAfterNow()) {
LOG.debug(
"Split backlog bytes for {} stream with latest record timestamp {}: {} (cached value)",
source.getStreamName(),
spec.getStreamName(),
latestRecordTimestamp,
lastBacklogBytes);
return lastBacklogBytes;
}

try {
lastBacklogBytes = kinesis.getBacklogBytes(source.getStreamName(), latestRecordTimestamp);
lastBacklogBytes = kinesis.getBacklogBytes(spec.getStreamName(), latestRecordTimestamp);
backlogBytesLastCheckTime = Instant.now();
} catch (TransientKinesisException e) {
LOG.warn(
"Transient exception occurred during backlog estimation for stream {}.",
source.getStreamName(),
spec.getStreamName(),
e);
}
LOG.info(
"Split backlog bytes for {} stream with {} latest record timestamp: {}",
source.getStreamName(),
spec.getStreamName(),
latestRecordTimestamp,
lastBacklogBytes);
return lastBacklogBytes;
}

ShardReadersPool createShardReadersPool() throws TransientKinesisException {
return new ShardReadersPool(
kinesis,
initialCheckpointGenerator.generate(kinesis),
watermarkPolicyFactory,
rateLimitPolicyFactory,
maxCapacityPerShard);
return new ShardReadersPool(spec, kinesis, checkpointGenerator.generate(kinesis));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.SerializableCoder;
import org.apache.beam.sdk.io.UnboundedSource;
import org.apache.beam.sdk.io.aws2.kinesis.KinesisIO.Read;
import org.apache.beam.sdk.options.PipelineOptions;
import org.joda.time.Duration;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand All @@ -37,53 +37,21 @@ class KinesisSource extends UnboundedSource<KinesisRecord, KinesisReaderCheckpoi

private static final Logger LOG = LoggerFactory.getLogger(KinesisSource.class);

private final AWSClientsProvider awsClientsProvider;
private final String streamName;
private final Duration upToDateThreshold;
private final WatermarkPolicyFactory watermarkPolicyFactory;
private final RateLimitPolicyFactory rateLimitPolicyFactory;
private CheckpointGenerator initialCheckpointGenerator;
private final Integer limit;
private final Integer maxCapacityPerShard;

KinesisSource(
AWSClientsProvider awsClientsProvider,
String streamName,
StartingPoint startingPoint,
Duration upToDateThreshold,
WatermarkPolicyFactory watermarkPolicyFactory,
RateLimitPolicyFactory rateLimitPolicyFactory,
Integer limit,
Integer maxCapacityPerShard) {
this(
awsClientsProvider,
new DynamicCheckpointGenerator(streamName, startingPoint),
streamName,
upToDateThreshold,
watermarkPolicyFactory,
rateLimitPolicyFactory,
limit,
maxCapacityPerShard);
private final Read spec;
private final CheckpointGenerator checkpointGenerator;

KinesisSource(Read read) {
this(read, new DynamicCheckpointGenerator(read.getStreamName(), read.getInitialPosition()));
}

private KinesisSource(Read spec, CheckpointGenerator initialCheckpoint) {
this.spec = checkNotNull(spec);
this.checkpointGenerator = checkNotNull(initialCheckpoint);
}

private KinesisSource(
AWSClientsProvider awsClientsProvider,
CheckpointGenerator initialCheckpoint,
String streamName,
Duration upToDateThreshold,
WatermarkPolicyFactory watermarkPolicyFactory,
RateLimitPolicyFactory rateLimitPolicyFactory,
Integer limit,
Integer maxCapacityPerShard) {
this.awsClientsProvider = awsClientsProvider;
this.initialCheckpointGenerator = initialCheckpoint;
this.streamName = streamName;
this.upToDateThreshold = upToDateThreshold;
this.watermarkPolicyFactory = watermarkPolicyFactory;
this.rateLimitPolicyFactory = rateLimitPolicyFactory;
this.limit = limit;
this.maxCapacityPerShard = maxCapacityPerShard;
validate();
private SimplifiedKinesisClient createClient() {
return SimplifiedKinesisClient.from(
spec.getAWSClientsProvider(), spec.getRequestRecordsLimit());
}

/**
Expand All @@ -92,23 +60,12 @@ private KinesisSource(
*/
@Override
public List<KinesisSource> split(int desiredNumSplits, PipelineOptions options) throws Exception {
KinesisReaderCheckpoint checkpoint =
initialCheckpointGenerator.generate(
SimplifiedKinesisClient.from(awsClientsProvider, limit));
KinesisReaderCheckpoint checkpoint = checkpointGenerator.generate(createClient());

List<KinesisSource> sources = newArrayList();

for (KinesisReaderCheckpoint partition : checkpoint.splitInto(desiredNumSplits)) {
sources.add(
new KinesisSource(
awsClientsProvider,
new StaticCheckpointGenerator(partition),
streamName,
upToDateThreshold,
watermarkPolicyFactory,
rateLimitPolicyFactory,
limit,
maxCapacityPerShard));
sources.add(new KinesisSource(spec, new StaticCheckpointGenerator(partition)));
}
return sources;
}
Expand All @@ -122,43 +79,22 @@ public List<KinesisSource> split(int desiredNumSplits, PipelineOptions options)
public UnboundedReader<KinesisRecord> createReader(
PipelineOptions options, KinesisReaderCheckpoint checkpointMark) {

CheckpointGenerator checkpointGenerator = initialCheckpointGenerator;

CheckpointGenerator checkpointGenerator = this.checkpointGenerator;
if (checkpointMark != null) {
checkpointGenerator = new StaticCheckpointGenerator(checkpointMark);
}

LOG.info("Creating new reader using {}", checkpointGenerator);

return new KinesisReader(
SimplifiedKinesisClient.from(awsClientsProvider, limit),
checkpointGenerator,
this,
watermarkPolicyFactory,
rateLimitPolicyFactory,
upToDateThreshold,
maxCapacityPerShard);
return new KinesisReader(spec, createClient(), checkpointGenerator, this);
}

/**
 * Returns the coder for this source's checkpoint marks: a {@link SerializableCoder} built for
 * {@link KinesisReaderCheckpoint}.
 */
@Override
public Coder<KinesisReaderCheckpoint> getCheckpointMarkCoder() {
return SerializableCoder.of(KinesisReaderCheckpoint.class);
}

/**
 * Runs {@code checkNotNull} on the AWS clients provider, the initial checkpoint generator, the
 * watermark policy factory and the rate limit policy factory.
 */
@Override
public void validate() {
checkNotNull(awsClientsProvider);
checkNotNull(initialCheckpointGenerator);
checkNotNull(watermarkPolicyFactory);
checkNotNull(rateLimitPolicyFactory);
}

/** Returns the coder for the records this source emits, obtained from {@link KinesisRecordCoder#of()}. */
@Override
public Coder<KinesisRecord> getOutputCoder() {
return KinesisRecordCoder.of();
}

/** Returns the value of this source's {@code streamName} field. */
String getStreamName() {
return streamName;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.StreamSupport;
import org.apache.beam.sdk.io.aws2.kinesis.KinesisIO.Read;
import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.annotations.VisibleForTesting;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap;
Expand Down Expand Up @@ -82,24 +83,16 @@ class ShardReadersPool {
/** A map for keeping the current number of records stored in a buffer per shard. */
private final ConcurrentMap<String, AtomicInteger> numberOfRecordsInAQueueByShard;

private final Read read;
private final SimplifiedKinesisClient kinesis;
private final WatermarkPolicyFactory watermarkPolicyFactory;
private final RateLimitPolicyFactory rateLimitPolicyFactory;
private final KinesisReaderCheckpoint initialCheckpoint;
private final int queueCapacityPerShard;
private final AtomicBoolean poolOpened = new AtomicBoolean(true);

ShardReadersPool(
SimplifiedKinesisClient kinesis,
KinesisReaderCheckpoint initialCheckpoint,
WatermarkPolicyFactory watermarkPolicyFactory,
RateLimitPolicyFactory rateLimitPolicyFactory,
int queueCapacityPerShard) {
Read read, SimplifiedKinesisClient kinesis, KinesisReaderCheckpoint initialCheckpoint) {
this.read = read;
this.kinesis = kinesis;
this.initialCheckpoint = initialCheckpoint;
this.watermarkPolicyFactory = watermarkPolicyFactory;
this.rateLimitPolicyFactory = rateLimitPolicyFactory;
this.queueCapacityPerShard = queueCapacityPerShard;
this.executorService = Executors.newCachedThreadPool();
this.numberOfRecordsInAQueueByShard = new ConcurrentHashMap<>();
this.shardIteratorsMap = new AtomicReference<>();
Expand All @@ -113,7 +106,7 @@ void start() throws TransientKinesisException {
shardIteratorsMap.set(shardsMap.build());
if (!shardIteratorsMap.get().isEmpty()) {
recordsQueue =
new ArrayBlockingQueue<>(queueCapacityPerShard * shardIteratorsMap.get().size());
new ArrayBlockingQueue<>(read.getMaxCapacityPerShard() * shardIteratorsMap.get().size());
String streamName = initialCheckpoint.getStreamName();
startReadingShards(shardIteratorsMap.get().values(), streamName);
} else {
Expand All @@ -136,8 +129,8 @@ void startReadingShards(Iterable<ShardRecordsIterator> shardRecordsIterators, St
getShardIdsFromRecordsIterators(shardRecordsIterators));
for (final ShardRecordsIterator recordsIterator : shardRecordsIterators) {
numberOfRecordsInAQueueByShard.put(recordsIterator.getShardId(), new AtomicInteger());
executorService.submit(
() -> readLoop(recordsIterator, rateLimitPolicyFactory.getRateLimitPolicy()));
RateLimitPolicy policy = read.getRateLimitPolicyFactory().getRateLimitPolicy();
executorService.submit(() -> readLoop(recordsIterator, policy));
}
}

Expand Down Expand Up @@ -283,7 +276,7 @@ KinesisReaderCheckpoint getCheckpointMark() {
ShardRecordsIterator createShardIterator(
SimplifiedKinesisClient kinesis, ShardCheckpoint checkpoint)
throws TransientKinesisException {
return new ShardRecordsIterator(checkpoint, kinesis, watermarkPolicyFactory);
return new ShardRecordsIterator(checkpoint, kinesis, read.getWatermarkPolicyFactory());
}

/**
Expand Down
Loading

0 comments on commit cc0b2c5

Please sign in to comment.