Skip to content

Commit

Permalink
Add AirbyteStreamNamespaceNamePair to MessageTracker Interface (#19361)
Browse files Browse the repository at this point in the history
Follow up to #19360.

This PR adjusts the MessageTracker interface to use the new Pair object.
  • Loading branch information
davinchia committed Nov 17, 2022
1 parent 90350c1 commit 2782b7a
Show file tree
Hide file tree
Showing 6 changed files with 51 additions and 40 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -455,7 +455,8 @@ private List<StreamSyncStats> getPerStreamStats(ReplicationStatus outputStatus)
syncStats.setRecordsCommitted(null);
}
return new StreamSyncStats()
.withStreamName(stream)
.withStreamName(stream.getName())
.withStreamNamespace(stream.getNamespace())
.withStats(syncStats);
}).collect(Collectors.toList());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import io.airbyte.protocol.models.AirbyteRecordMessage;
import io.airbyte.protocol.models.AirbyteStateMessage;
import io.airbyte.protocol.models.AirbyteStateMessage.AirbyteStateType;
import io.airbyte.protocol.models.AirbyteStreamNameNamespacePair;
import io.airbyte.protocol.models.AirbyteTraceMessage;
import io.airbyte.workers.helper.FailureHelper;
import io.airbyte.workers.internal.StateMetricsTracker.StateMetricsTrackerNoStateMatchException;
Expand All @@ -33,6 +34,7 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Optional;
import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.Collectors;
Expand All @@ -48,7 +50,7 @@ public class AirbyteMessageTracker implements MessageTracker {
private final AtomicReference<State> destinationOutputState;
private final Map<Short, Long> streamToRunningCount;
private final HashFunction hashFunction;
private final BiMap<String, Short> streamNameToIndex;
private final BiMap<AirbyteStreamNameNamespacePair, Short> nameNamespacePairToIndex;
private final Map<Short, Long> streamToTotalBytesEmitted;
private final Map<Short, Long> streamToTotalRecordsEmitted;
private final StateDeltaTracker stateDeltaTracker;
Expand Down Expand Up @@ -89,7 +91,7 @@ protected AirbyteMessageTracker(final StateDeltaTracker stateDeltaTracker,
this.sourceOutputState = new AtomicReference<>();
this.destinationOutputState = new AtomicReference<>();
this.streamToRunningCount = new HashMap<>();
this.streamNameToIndex = HashBiMap.create();
this.nameNamespacePairToIndex = HashBiMap.create();
this.hashFunction = Hashing.murmur3_32_fixed();
this.streamToTotalBytesEmitted = new HashMap<>();
this.streamToTotalRecordsEmitted = new HashMap<>();
Expand Down Expand Up @@ -139,7 +141,7 @@ private void handleSourceEmittedRecord(final AirbyteRecordMessage recordMessage)
stateMetricsTracker.setFirstRecordReceivedAt(LocalDateTime.now());
}

final short streamIndex = getStreamIndex(recordMessage.getStream());
final short streamIndex = getStreamIndex(AirbyteStreamNameNamespacePair.fromRecordMessage(recordMessage));

final long currentRunningCount = streamToRunningCount.getOrDefault(streamIndex, 0L);
streamToRunningCount.put(streamIndex, currentRunningCount + 1);
Expand Down Expand Up @@ -269,12 +271,12 @@ private void handleEmittedEstimateTrace(final AirbyteTraceMessage estimateTraceM

}

private short getStreamIndex(final String streamName) {
if (!streamNameToIndex.containsKey(streamName)) {
streamNameToIndex.put(streamName, nextStreamIndex);
private short getStreamIndex(final AirbyteStreamNameNamespacePair pair) {
if (!nameNamespacePairToIndex.containsKey(pair)) {
nameNamespacePairToIndex.put(pair, nextStreamIndex);
nextStreamIndex++;
}
return streamNameToIndex.get(streamName);
return nameNamespacePairToIndex.get(pair);
}

private int getStateHashCode(final AirbyteStateMessage stateMessage) {
Expand Down Expand Up @@ -347,36 +349,32 @@ public Optional<State> getDestinationOutputState() {
* because committed record counts cannot be reliably computed.
*/
@Override
public Optional<Map<String, Long>> getStreamToCommittedRecords() {
public Optional<Map<AirbyteStreamNameNamespacePair, Long>> getStreamToCommittedRecords() {
if (unreliableCommittedCounts) {
return Optional.empty();
}
final Map<Short, Long> streamIndexToCommittedRecordCount = stateDeltaTracker.getStreamToCommittedRecords();
return Optional.of(
streamIndexToCommittedRecordCount.entrySet().stream().collect(
Collectors.toMap(
entry -> streamNameToIndex.inverse().get(entry.getKey()),
Map.Entry::getValue)));
Collectors.toMap(entry -> nameNamespacePairToIndex.inverse().get(entry.getKey()), Entry::getValue)));
}

/**
* Swap out stream indices for stream names and return total records emitted by stream.
*/
@Override
public Map<String, Long> getStreamToEmittedRecords() {
public Map<AirbyteStreamNameNamespacePair, Long> getStreamToEmittedRecords() {
return streamToTotalRecordsEmitted.entrySet().stream().collect(Collectors.toMap(
entry -> streamNameToIndex.inverse().get(entry.getKey()),
Map.Entry::getValue));
entry -> nameNamespacePairToIndex.inverse().get(entry.getKey()), Entry::getValue));
}

/**
* Swap out stream indices for stream names and return total bytes emitted by stream.
*/
@Override
public Map<String, Long> getStreamToEmittedBytes() {
public Map<AirbyteStreamNameNamespacePair, Long> getStreamToEmittedBytes() {
return streamToTotalBytesEmitted.entrySet().stream().collect(Collectors.toMap(
entry -> streamNameToIndex.inverse().get(entry.getKey()),
Map.Entry::getValue));
entry -> nameNamespacePairToIndex.inverse().get(entry.getKey()), Entry::getValue));
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import io.airbyte.config.FailureReason;
import io.airbyte.config.State;
import io.airbyte.protocol.models.AirbyteMessage;
import io.airbyte.protocol.models.AirbyteStreamNameNamespacePair;
import io.airbyte.protocol.models.AirbyteTraceMessage;
import java.util.Map;
import java.util.Optional;
Expand Down Expand Up @@ -55,23 +56,23 @@ public interface MessageTracker {
* @return returns a map of committed record count by stream name. If committed record counts cannot
* be computed, empty.
*/
Optional<Map<String, Long>> getStreamToCommittedRecords();
Optional<Map<AirbyteStreamNameNamespacePair, Long>> getStreamToCommittedRecords();

/**
* Get the per-stream emitted record count. This includes messages that were emitted by the source,
* but never committed by the destination.
*
* @return returns a map of emitted record count by stream name.
*/
Map<String, Long> getStreamToEmittedRecords();
Map<AirbyteStreamNameNamespacePair, Long> getStreamToEmittedRecords();

/**
* Get the per-stream emitted byte count. This includes messages that were emitted by the source,
* but never committed by the destination.
*
* @return returns a map of emitted record count by stream name.
*/
Map<String, Long> getStreamToEmittedBytes();
Map<AirbyteStreamNameNamespacePair, Long> getStreamToEmittedBytes();

/**
* Get the overall emitted record count. This includes messages that were emitted by the source, but
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,8 @@ class DefaultReplicationWorkerTest {
private static final AirbyteTraceMessage ERROR_TRACE_MESSAGE =
AirbyteMessageUtils.createErrorTraceMessage("a connector error occurred", Double.valueOf(123));
private static final String STREAM1 = "stream1";

private static final String NAMESPACE = "namespace";
private static final String INDUCED_EXCEPTION = "induced exception";

private Path jobRoot;
Expand Down Expand Up @@ -483,8 +485,9 @@ void testPopulatesOutputOnSuccess() throws WorkerException {
when(messageTracker.getTotalBytesEmitted()).thenReturn(100L);
when(messageTracker.getTotalSourceStateMessagesEmitted()).thenReturn(3L);
when(messageTracker.getTotalDestinationStateMessagesEmitted()).thenReturn(1L);
when(messageTracker.getStreamToEmittedBytes()).thenReturn(Collections.singletonMap(STREAM1, 100L));
when(messageTracker.getStreamToEmittedRecords()).thenReturn(Collections.singletonMap(STREAM1, 12L));
when(messageTracker.getStreamToEmittedBytes()).thenReturn(Collections.singletonMap(new AirbyteStreamNameNamespacePair(STREAM1, NAMESPACE), 100L));
when(messageTracker.getStreamToEmittedRecords())
.thenReturn(Collections.singletonMap(new AirbyteStreamNameNamespacePair(STREAM1, NAMESPACE), 12L));
when(messageTracker.getMaxSecondsToReceiveSourceStateMessage()).thenReturn(5L);
when(messageTracker.getMeanSecondsToReceiveSourceStateMessage()).thenReturn(4L);
when(messageTracker.getMaxSecondsBetweenStateMessageEmittedAndCommitted()).thenReturn(Optional.of(6L));
Expand Down Expand Up @@ -519,6 +522,7 @@ void testPopulatesOutputOnSuccess() throws WorkerException {
.withStreamStats(Collections.singletonList(
new StreamSyncStats()
.withStreamName(STREAM1)
.withStreamNamespace(NAMESPACE)
.withStats(new SyncStats()
.withBytesEmitted(100L)
.withRecordsEmitted(12L)
Expand Down Expand Up @@ -599,9 +603,11 @@ void testPopulatesStatsOnFailureIfAvailable() throws Exception {
when(messageTracker.getTotalRecordsCommitted()).thenReturn(Optional.of(6L));
when(messageTracker.getTotalSourceStateMessagesEmitted()).thenReturn(3L);
when(messageTracker.getTotalDestinationStateMessagesEmitted()).thenReturn(2L);
when(messageTracker.getStreamToEmittedBytes()).thenReturn(Collections.singletonMap(STREAM1, 100L));
when(messageTracker.getStreamToEmittedRecords()).thenReturn(Collections.singletonMap(STREAM1, 12L));
when(messageTracker.getStreamToCommittedRecords()).thenReturn(Optional.of(Collections.singletonMap(STREAM1, 6L)));
when(messageTracker.getStreamToEmittedBytes()).thenReturn(Collections.singletonMap(new AirbyteStreamNameNamespacePair(STREAM1, NAMESPACE), 100L));
when(messageTracker.getStreamToEmittedRecords())
.thenReturn(Collections.singletonMap(new AirbyteStreamNameNamespacePair(STREAM1, NAMESPACE), 12L));
when(messageTracker.getStreamToCommittedRecords())
.thenReturn(Optional.of(Collections.singletonMap(new AirbyteStreamNameNamespacePair(STREAM1, NAMESPACE), 6L)));
when(messageTracker.getMaxSecondsToReceiveSourceStateMessage()).thenReturn(10L);
when(messageTracker.getMeanSecondsToReceiveSourceStateMessage()).thenReturn(8L);
when(messageTracker.getMaxSecondsBetweenStateMessageEmittedAndCommitted()).thenReturn(Optional.of(12L));
Expand Down Expand Up @@ -631,6 +637,7 @@ void testPopulatesStatsOnFailureIfAvailable() throws Exception {
final List<StreamSyncStats> expectedStreamStats = Collections.singletonList(
new StreamSyncStats()
.withStreamName(STREAM1)
.withStreamNamespace(NAMESPACE)
.withStats(new SyncStats()
.withBytesEmitted(100L)
.withRecordsEmitted(12L)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import io.airbyte.config.FailureReason;
import io.airbyte.config.State;
import io.airbyte.protocol.models.AirbyteMessage;
import io.airbyte.protocol.models.AirbyteStreamNameNamespacePair;
import io.airbyte.workers.helper.FailureHelper;
import io.airbyte.workers.internal.StateDeltaTracker.StateDeltaTrackerException;
import io.airbyte.workers.internal.state_aggregator.StateAggregator;
Expand Down Expand Up @@ -107,10 +108,10 @@ void testEmittedRecordsByStream() {
messageTracker.acceptFromSource(r3);
messageTracker.acceptFromSource(r3);

final Map<String, Long> expected = new HashMap<>();
expected.put(STREAM_1, 1L);
expected.put(STREAM_2, 2L);
expected.put(STREAM_3, 3L);
final HashMap<AirbyteStreamNameNamespacePair, Long> expected = new HashMap<>();
expected.put(AirbyteStreamNameNamespacePair.fromRecordMessage(r1.getRecord()), 1L);
expected.put(AirbyteStreamNameNamespacePair.fromRecordMessage(r2.getRecord()), 2L);
expected.put(AirbyteStreamNameNamespacePair.fromRecordMessage(r3.getRecord()), 3L);

assertEquals(expected, messageTracker.getStreamToEmittedRecords());
}
Expand All @@ -132,10 +133,10 @@ void testEmittedBytesByStream() {
messageTracker.acceptFromSource(r3);
messageTracker.acceptFromSource(r3);

final Map<String, Long> expected = new HashMap<>();
expected.put(STREAM_1, r1Bytes);
expected.put(STREAM_2, r2Bytes * 2);
expected.put(STREAM_3, r3Bytes * 3);
final Map<AirbyteStreamNameNamespacePair, Long> expected = new HashMap<>();
expected.put(AirbyteStreamNameNamespacePair.fromRecordMessage(r1.getRecord()), r1Bytes);
expected.put(AirbyteStreamNameNamespacePair.fromRecordMessage(r2.getRecord()), r2Bytes * 2);
expected.put(AirbyteStreamNameNamespacePair.fromRecordMessage(r3.getRecord()), r3Bytes * 3);

assertEquals(expected, messageTracker.getStreamToEmittedBytes());
}
Expand All @@ -160,14 +161,14 @@ void testGetCommittedRecordsByStream() {
messageTracker.acceptFromSource(s2); // emit state 2

final Map<Short, Long> countsByIndex = new HashMap<>();
final Map<String, Long> expected = new HashMap<>();
final Map<AirbyteStreamNameNamespacePair, Long> expected = new HashMap<>();
Mockito.when(mStateDeltaTracker.getStreamToCommittedRecords()).thenReturn(countsByIndex);

countsByIndex.put((short) 0, 1L);
countsByIndex.put((short) 1, 2L);
// result only contains counts up to state 1
expected.put(STREAM_1, 1L);
expected.put(STREAM_2, 2L);
expected.put(AirbyteStreamNameNamespacePair.fromRecordMessage(r1.getRecord()), 1L);
expected.put(AirbyteStreamNameNamespacePair.fromRecordMessage(r2.getRecord()), 2L);
assertEquals(expected, messageTracker.getStreamToCommittedRecords().get());

countsByIndex.clear();
Expand All @@ -177,9 +178,9 @@ void testGetCommittedRecordsByStream() {
countsByIndex.put((short) 1, 3L);
countsByIndex.put((short) 2, 1L);
// result updated with counts between state 1 and state 2
expected.put(STREAM_1, 3L);
expected.put(STREAM_2, 3L);
expected.put(STREAM_3, 1L);
expected.put(AirbyteStreamNameNamespacePair.fromRecordMessage(r1.getRecord()), 3L);
expected.put(AirbyteStreamNameNamespacePair.fromRecordMessage(r2.getRecord()), 3L);
expected.put(AirbyteStreamNameNamespacePair.fromRecordMessage(r3.getRecord()), 1L);
assertEquals(expected, messageTracker.getStreamToCommittedRecords().get());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,8 @@ additionalProperties: false
properties:
streamName:
type: string
# Not required as not all sources emits a namespace for each Stream.
streamNamespace:
type: string
stats:
"$ref": SyncStats.yaml

0 comments on commit 2782b7a

Please sign in to comment.