From 2782b7aaf1acc5cd318824467c8b664472aa4541 Mon Sep 17 00:00:00 2001 From: Davin Chia Date: Thu, 17 Nov 2022 13:26:52 -0800 Subject: [PATCH] Add AirbyteStreamNamespaceNamePair to MessageTracker Interface (#19361) Follow up to #19360. This PR adjusts the MessageTracker interface to use the new Pair object. --- .../general/DefaultReplicationWorker.java | 3 +- .../internal/AirbyteMessageTracker.java | 32 +++++++++---------- .../workers/internal/MessageTracker.java | 7 ++-- .../general/DefaultReplicationWorkerTest.java | 17 +++++++--- .../internal/AirbyteMessageTrackerTest.java | 29 +++++++++-------- .../main/resources/types/StreamSyncStats.yaml | 3 ++ 6 files changed, 51 insertions(+), 40 deletions(-) diff --git a/airbyte-commons-worker/src/main/java/io/airbyte/workers/general/DefaultReplicationWorker.java b/airbyte-commons-worker/src/main/java/io/airbyte/workers/general/DefaultReplicationWorker.java index f27a8de58c113..ec52aeadd94eb 100644 --- a/airbyte-commons-worker/src/main/java/io/airbyte/workers/general/DefaultReplicationWorker.java +++ b/airbyte-commons-worker/src/main/java/io/airbyte/workers/general/DefaultReplicationWorker.java @@ -455,7 +455,8 @@ private List getPerStreamStats(ReplicationStatus outputStatus) syncStats.setRecordsCommitted(null); } return new StreamSyncStats() - .withStreamName(stream) + .withStreamName(stream.getName()) + .withStreamNamespace(stream.getNamespace()) .withStats(syncStats); }).collect(Collectors.toList()); } diff --git a/airbyte-commons-worker/src/main/java/io/airbyte/workers/internal/AirbyteMessageTracker.java b/airbyte-commons-worker/src/main/java/io/airbyte/workers/internal/AirbyteMessageTracker.java index 44539d0e23f92..f7d8caa4be698 100644 --- a/airbyte-commons-worker/src/main/java/io/airbyte/workers/internal/AirbyteMessageTracker.java +++ b/airbyte-commons-worker/src/main/java/io/airbyte/workers/internal/AirbyteMessageTracker.java @@ -23,6 +23,7 @@ import io.airbyte.protocol.models.AirbyteRecordMessage; import io.airbyte.protocol.models.AirbyteStateMessage; import io.airbyte.protocol.models.AirbyteStateMessage.AirbyteStateType; +import io.airbyte.protocol.models.AirbyteStreamNameNamespacePair; import io.airbyte.protocol.models.AirbyteTraceMessage; import io.airbyte.workers.helper.FailureHelper; import io.airbyte.workers.internal.StateMetricsTracker.StateMetricsTrackerNoStateMatchException; @@ -33,6 +34,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Map.Entry; import java.util.Optional; import java.util.concurrent.atomic.AtomicReference; import java.util.stream.Collectors; @@ -48,7 +50,7 @@ public class AirbyteMessageTracker implements MessageTracker { private final AtomicReference destinationOutputState; private final Map streamToRunningCount; private final HashFunction hashFunction; - private final BiMap streamNameToIndex; + private final BiMap nameNamespacePairToIndex; private final Map streamToTotalBytesEmitted; private final Map streamToTotalRecordsEmitted; private final StateDeltaTracker stateDeltaTracker; @@ -89,7 +91,7 @@ protected AirbyteMessageTracker(final StateDeltaTracker stateDeltaTracker, this.sourceOutputState = new AtomicReference<>(); this.destinationOutputState = new AtomicReference<>(); this.streamToRunningCount = new HashMap<>(); - this.streamNameToIndex = HashBiMap.create(); + this.nameNamespacePairToIndex = HashBiMap.create(); this.hashFunction = Hashing.murmur3_32_fixed(); this.streamToTotalBytesEmitted = new HashMap<>(); this.streamToTotalRecordsEmitted = new HashMap<>(); @@ -139,7 +141,7 @@ private void handleSourceEmittedRecord(final AirbyteRecordMessage recordMessage) stateMetricsTracker.setFirstRecordReceivedAt(LocalDateTime.now()); } - final short streamIndex = getStreamIndex(recordMessage.getStream()); + final short streamIndex = getStreamIndex(AirbyteStreamNameNamespacePair.fromRecordMessage(recordMessage)); final long currentRunningCount = streamToRunningCount.getOrDefault(streamIndex, 0L); streamToRunningCount.put(streamIndex, currentRunningCount + 1); @@ -269,12 +271,12 @@ private void handleEmittedEstimateTrace(final AirbyteTraceMessage estimateTraceM } - private short getStreamIndex(final String streamName) { - if (!streamNameToIndex.containsKey(streamName)) { - streamNameToIndex.put(streamName, nextStreamIndex); + private short getStreamIndex(final AirbyteStreamNameNamespacePair pair) { + if (!nameNamespacePairToIndex.containsKey(pair)) { + nameNamespacePairToIndex.put(pair, nextStreamIndex); nextStreamIndex++; } - return streamNameToIndex.get(streamName); + return nameNamespacePairToIndex.get(pair); } private int getStateHashCode(final AirbyteStateMessage stateMessage) { @@ -347,36 +349,32 @@ public Optional getDestinationOutputState() { * because committed record counts cannot be reliably computed. */ @Override - public Optional> getStreamToCommittedRecords() { + public Optional> getStreamToCommittedRecords() { if (unreliableCommittedCounts) { return Optional.empty(); } final Map streamIndexToCommittedRecordCount = stateDeltaTracker.getStreamToCommittedRecords(); return Optional.of( streamIndexToCommittedRecordCount.entrySet().stream().collect( - Collectors.toMap( - entry -> streamNameToIndex.inverse().get(entry.getKey()), - Map.Entry::getValue))); + Collectors.toMap(entry -> nameNamespacePairToIndex.inverse().get(entry.getKey()), Entry::getValue))); } /** * Swap out stream indices for stream names and return total records emitted by stream. */ @Override - public Map getStreamToEmittedRecords() { + public Map getStreamToEmittedRecords() { return streamToTotalRecordsEmitted.entrySet().stream().collect(Collectors.toMap( - entry -> streamNameToIndex.inverse().get(entry.getKey()), - Map.Entry::getValue)); + entry -> nameNamespacePairToIndex.inverse().get(entry.getKey()), Entry::getValue)); } /** * Swap out stream indices for stream names and return total bytes emitted by stream. */ @Override - public Map getStreamToEmittedBytes() { + public Map getStreamToEmittedBytes() { return streamToTotalBytesEmitted.entrySet().stream().collect(Collectors.toMap( - entry -> streamNameToIndex.inverse().get(entry.getKey()), - Map.Entry::getValue)); + entry -> nameNamespacePairToIndex.inverse().get(entry.getKey()), Entry::getValue)); } /** diff --git a/airbyte-commons-worker/src/main/java/io/airbyte/workers/internal/MessageTracker.java b/airbyte-commons-worker/src/main/java/io/airbyte/workers/internal/MessageTracker.java index 86994fd785c85..09507ec7a374e 100644 --- a/airbyte-commons-worker/src/main/java/io/airbyte/workers/internal/MessageTracker.java +++ b/airbyte-commons-worker/src/main/java/io/airbyte/workers/internal/MessageTracker.java @@ -7,6 +7,7 @@ import io.airbyte.config.FailureReason; import io.airbyte.config.State; import io.airbyte.protocol.models.AirbyteMessage; +import io.airbyte.protocol.models.AirbyteStreamNameNamespacePair; import io.airbyte.protocol.models.AirbyteTraceMessage; import java.util.Map; import java.util.Optional; @@ -55,7 +56,7 @@ public interface MessageTracker { * @return returns a map of committed record count by stream name. If committed record counts cannot * be computed, empty. */ - Optional> getStreamToCommittedRecords(); + Optional> getStreamToCommittedRecords(); /** * Get the per-stream emitted record count. This includes messages that were emitted by the source, @@ -63,7 +64,7 @@ public interface MessageTracker { * * @return returns a map of emitted record count by stream name. */ - Map getStreamToEmittedRecords(); + Map getStreamToEmittedRecords(); /** * Get the per-stream emitted byte count. This includes messages that were emitted by the source, @@ -71,7 +72,7 @@ public interface MessageTracker { * * @return returns a map of emitted record count by stream name. */ - Map getStreamToEmittedBytes(); + Map getStreamToEmittedBytes(); /** * Get the overall emitted record count. This includes messages that were emitted by the source, but diff --git a/airbyte-commons-worker/src/test/java/io/airbyte/workers/general/DefaultReplicationWorkerTest.java b/airbyte-commons-worker/src/test/java/io/airbyte/workers/general/DefaultReplicationWorkerTest.java index c40ce959f2822..7b915ed4943ca 100644 --- a/airbyte-commons-worker/src/test/java/io/airbyte/workers/general/DefaultReplicationWorkerTest.java +++ b/airbyte-commons-worker/src/test/java/io/airbyte/workers/general/DefaultReplicationWorkerTest.java @@ -91,6 +91,8 @@ class DefaultReplicationWorkerTest { private static final AirbyteTraceMessage ERROR_TRACE_MESSAGE = AirbyteMessageUtils.createErrorTraceMessage("a connector error occurred", Double.valueOf(123)); private static final String STREAM1 = "stream1"; + + private static final String NAMESPACE = "namespace"; private static final String INDUCED_EXCEPTION = "induced exception"; private Path jobRoot; @@ -483,8 +485,9 @@ void testPopulatesOutputOnSuccess() throws WorkerException { when(messageTracker.getTotalBytesEmitted()).thenReturn(100L); when(messageTracker.getTotalSourceStateMessagesEmitted()).thenReturn(3L); when(messageTracker.getTotalDestinationStateMessagesEmitted()).thenReturn(1L); - when(messageTracker.getStreamToEmittedBytes()).thenReturn(Collections.singletonMap(STREAM1, 100L)); - when(messageTracker.getStreamToEmittedRecords()).thenReturn(Collections.singletonMap(STREAM1, 12L)); + when(messageTracker.getStreamToEmittedBytes()).thenReturn(Collections.singletonMap(new AirbyteStreamNameNamespacePair(STREAM1, NAMESPACE), 100L)); + when(messageTracker.getStreamToEmittedRecords()) + .thenReturn(Collections.singletonMap(new AirbyteStreamNameNamespacePair(STREAM1, NAMESPACE), 12L)); when(messageTracker.getMaxSecondsToReceiveSourceStateMessage()).thenReturn(5L); when(messageTracker.getMeanSecondsToReceiveSourceStateMessage()).thenReturn(4L); when(messageTracker.getMaxSecondsBetweenStateMessageEmittedAndCommitted()).thenReturn(Optional.of(6L)); @@ -519,6 +522,7 @@ void testPopulatesOutputOnSuccess() throws WorkerException { .withStreamStats(Collections.singletonList( new StreamSyncStats() .withStreamName(STREAM1) + .withStreamNamespace(NAMESPACE) .withStats(new SyncStats() .withBytesEmitted(100L) .withRecordsEmitted(12L) @@ -599,9 +603,11 @@ void testPopulatesStatsOnFailureIfAvailable() throws Exception { when(messageTracker.getTotalRecordsCommitted()).thenReturn(Optional.of(6L)); when(messageTracker.getTotalSourceStateMessagesEmitted()).thenReturn(3L); when(messageTracker.getTotalDestinationStateMessagesEmitted()).thenReturn(2L); - when(messageTracker.getStreamToEmittedBytes()).thenReturn(Collections.singletonMap(STREAM1, 100L)); - when(messageTracker.getStreamToEmittedRecords()).thenReturn(Collections.singletonMap(STREAM1, 12L)); - when(messageTracker.getStreamToCommittedRecords()).thenReturn(Optional.of(Collections.singletonMap(STREAM1, 6L))); + when(messageTracker.getStreamToEmittedBytes()).thenReturn(Collections.singletonMap(new AirbyteStreamNameNamespacePair(STREAM1, NAMESPACE), 100L)); + when(messageTracker.getStreamToEmittedRecords()) + .thenReturn(Collections.singletonMap(new AirbyteStreamNameNamespacePair(STREAM1, NAMESPACE), 12L)); + when(messageTracker.getStreamToCommittedRecords()) + .thenReturn(Optional.of(Collections.singletonMap(new AirbyteStreamNameNamespacePair(STREAM1, NAMESPACE), 6L))); when(messageTracker.getMaxSecondsToReceiveSourceStateMessage()).thenReturn(10L); when(messageTracker.getMeanSecondsToReceiveSourceStateMessage()).thenReturn(8L); when(messageTracker.getMaxSecondsBetweenStateMessageEmittedAndCommitted()).thenReturn(Optional.of(12L)); @@ -631,6 +637,7 @@ void testPopulatesStatsOnFailureIfAvailable() throws Exception { final List expectedStreamStats = Collections.singletonList( new StreamSyncStats() .withStreamName(STREAM1) + .withStreamNamespace(NAMESPACE) .withStats(new SyncStats() .withBytesEmitted(100L) .withRecordsEmitted(12L) diff --git a/airbyte-commons-worker/src/test/java/io/airbyte/workers/internal/AirbyteMessageTrackerTest.java b/airbyte-commons-worker/src/test/java/io/airbyte/workers/internal/AirbyteMessageTrackerTest.java index 313debe985ebc..5123b299453ce 100644 --- a/airbyte-commons-worker/src/test/java/io/airbyte/workers/internal/AirbyteMessageTrackerTest.java +++ b/airbyte-commons-worker/src/test/java/io/airbyte/workers/internal/AirbyteMessageTrackerTest.java @@ -11,6 +11,7 @@ import io.airbyte.config.FailureReason; import io.airbyte.config.State; import io.airbyte.protocol.models.AirbyteMessage; +import io.airbyte.protocol.models.AirbyteStreamNameNamespacePair; import io.airbyte.workers.helper.FailureHelper; import io.airbyte.workers.internal.StateDeltaTracker.StateDeltaTrackerException; import io.airbyte.workers.internal.state_aggregator.StateAggregator; @@ -107,10 +108,10 @@ void testEmittedRecordsByStream() { messageTracker.acceptFromSource(r3); messageTracker.acceptFromSource(r3); - final Map expected = new HashMap<>(); - expected.put(STREAM_1, 1L); - expected.put(STREAM_2, 2L); - expected.put(STREAM_3, 3L); + final HashMap expected = new HashMap<>(); + expected.put(AirbyteStreamNameNamespacePair.fromRecordMessage(r1.getRecord()), 1L); + expected.put(AirbyteStreamNameNamespacePair.fromRecordMessage(r2.getRecord()), 2L); + expected.put(AirbyteStreamNameNamespacePair.fromRecordMessage(r3.getRecord()), 3L); assertEquals(expected, messageTracker.getStreamToEmittedRecords()); } @@ -132,10 +133,10 @@ void testEmittedBytesByStream() { messageTracker.acceptFromSource(r3); messageTracker.acceptFromSource(r3); - final Map expected = new HashMap<>(); - expected.put(STREAM_1, r1Bytes); - expected.put(STREAM_2, r2Bytes * 2); - expected.put(STREAM_3, r3Bytes * 3); + final Map expected = new HashMap<>(); + expected.put(AirbyteStreamNameNamespacePair.fromRecordMessage(r1.getRecord()), r1Bytes); + expected.put(AirbyteStreamNameNamespacePair.fromRecordMessage(r2.getRecord()), r2Bytes * 2); + expected.put(AirbyteStreamNameNamespacePair.fromRecordMessage(r3.getRecord()), r3Bytes * 3); assertEquals(expected, messageTracker.getStreamToEmittedBytes()); } @@ -160,14 +161,14 @@ void testGetCommittedRecordsByStream() { messageTracker.acceptFromSource(s2); // emit state 2 final Map countsByIndex = new HashMap<>(); - final Map expected = new HashMap<>(); + final Map expected = new HashMap<>(); Mockito.when(mStateDeltaTracker.getStreamToCommittedRecords()).thenReturn(countsByIndex); countsByIndex.put((short) 0, 1L); countsByIndex.put((short) 1, 2L); // result only contains counts up to state 1 - expected.put(STREAM_1, 1L); - expected.put(STREAM_2, 2L); + expected.put(AirbyteStreamNameNamespacePair.fromRecordMessage(r1.getRecord()), 1L); + expected.put(AirbyteStreamNameNamespacePair.fromRecordMessage(r2.getRecord()), 2L); assertEquals(expected, messageTracker.getStreamToCommittedRecords().get()); countsByIndex.clear(); @@ -177,9 +178,9 @@ void testGetCommittedRecordsByStream() { countsByIndex.put((short) 1, 3L); countsByIndex.put((short) 2, 1L); // result updated with counts between state 1 and state 2 - expected.put(STREAM_1, 3L); - expected.put(STREAM_2, 3L); - expected.put(STREAM_3, 1L); + expected.put(AirbyteStreamNameNamespacePair.fromRecordMessage(r1.getRecord()), 3L); + expected.put(AirbyteStreamNameNamespacePair.fromRecordMessage(r2.getRecord()), 3L); + expected.put(AirbyteStreamNameNamespacePair.fromRecordMessage(r3.getRecord()), 1L); assertEquals(expected, messageTracker.getStreamToCommittedRecords().get()); } diff --git a/airbyte-config/config-models/src/main/resources/types/StreamSyncStats.yaml b/airbyte-config/config-models/src/main/resources/types/StreamSyncStats.yaml index c20003f72c5dc..5ce73ce21d1d4 100644 --- a/airbyte-config/config-models/src/main/resources/types/StreamSyncStats.yaml +++ b/airbyte-config/config-models/src/main/resources/types/StreamSyncStats.yaml @@ -11,5 +11,8 @@ additionalProperties: false properties: streamName: type: string + # Not required as not all sources emits a namespace for each Stream. + streamNamespace: + type: string stats: "$ref": SyncStats.yaml