diff --git a/airbyte_cdk/manifest_server/command_processor/processor.py b/airbyte_cdk/manifest_server/command_processor/processor.py
index 16d14a799..166d5e391 100644
--- a/airbyte_cdk/manifest_server/command_processor/processor.py
+++ b/airbyte_cdk/manifest_server/command_processor/processor.py
@@ -41,7 +41,6 @@ def test_read(
         """
         Test the read method of the source.
         """
-
         test_read_handler = TestReader(
             max_pages_per_slice=page_limit,
             max_slices=slice_limit,
diff --git a/airbyte_cdk/sources/declarative/stream_slicers/declarative_partition_generator.py b/airbyte_cdk/sources/declarative/stream_slicers/declarative_partition_generator.py
index 809936ae0..47c32d1cc 100644
--- a/airbyte_cdk/sources/declarative/stream_slicers/declarative_partition_generator.py
+++ b/airbyte_cdk/sources/declarative/stream_slicers/declarative_partition_generator.py
@@ -14,10 +14,21 @@
 from airbyte_cdk.sources.types import Record, StreamSlice
 from airbyte_cdk.utils.slice_hasher import SliceHasher
 
+
 # For Connector Builder test read operations, we track the total number of records
-# read for the stream at the global level so that we can stop reading early if we
-# exceed the record limit
-total_record_counter = 0
+# read for the stream so that we can stop reading early if we exceed the record limit.
+class RecordCounter:
+    def __init__(self) -> None:
+        self.total_record_counter = 0
+
+    def increment(self) -> None:
+        self.total_record_counter += 1
+
+    def reset(self) -> None:
+        self.total_record_counter = 0
+
+    def get_total_records(self) -> int:
+        return self.total_record_counter
 
 
 class SchemaLoaderCachingDecorator(SchemaLoader):
@@ -51,6 +62,7 @@ def __init__(
         self._retriever = retriever
         self._message_repository = message_repository
         self._max_records_limit = max_records_limit
+        self._record_counter = RecordCounter()
 
     def create(self, stream_slice: StreamSlice) -> Partition:
         return DeclarativePartition(
@@ -60,6 +72,7 @@ def create(self, stream_slice: StreamSlice) -> Partition:
             message_repository=self._message_repository,
             max_records_limit=self._max_records_limit,
             stream_slice=stream_slice,
+            record_counter=self._record_counter,
         )
 
 
@@ -72,6 +85,7 @@ def __init__(
         message_repository: MessageRepository,
         max_records_limit: Optional[int],
         stream_slice: StreamSlice,
+        record_counter: RecordCounter,
     ):
         self._stream_name = stream_name
         self._schema_loader = schema_loader
@@ -80,17 +94,17 @@ def __init__(
         self._max_records_limit = max_records_limit
         self._stream_slice = stream_slice
         self._hash = SliceHasher.hash(self._stream_name, self._stream_slice)
+        self._record_counter = record_counter
 
     def read(self) -> Iterable[Record]:
         if self._max_records_limit is not None:
-            global total_record_counter
-            if total_record_counter >= self._max_records_limit:
+            if self._record_counter.get_total_records() >= self._max_records_limit:
                 return
         for stream_data in self._retriever.read_records(
             self._schema_loader.get_json_schema(), self._stream_slice
         ):
             if self._max_records_limit is not None:
-                if total_record_counter >= self._max_records_limit:
+                if self._record_counter.get_total_records() >= self._max_records_limit:
                     break
 
             if isinstance(stream_data, Mapping):
@@ -108,7 +122,7 @@ def read(self) -> Iterable[Record]:
                 self._message_repository.emit_message(stream_data)
 
             if self._max_records_limit is not None:
-                total_record_counter += 1
+                self._record_counter.increment()
 
     def to_slice(self) -> Optional[Mapping[str, Any]]:
         return self._stream_slice
diff --git a/unit_tests/sources/declarative/incremental/test_concurrent_perpartitioncursor.py b/unit_tests/sources/declarative/incremental/test_concurrent_perpartitioncursor.py
index cb774bda7..fb38bb343 100644
--- a/unit_tests/sources/declarative/incremental/test_concurrent_perpartitioncursor.py
+++ b/unit_tests/sources/declarative/incremental/test_concurrent_perpartitioncursor.py
@@ -23,6 +23,7 @@
 from airbyte_cdk.sources.declarative.schema import InlineSchemaLoader
 from airbyte_cdk.sources.declarative.stream_slicers.declarative_partition_generator import (
     DeclarativePartition,
+    RecordCounter,
 )
 from airbyte_cdk.sources.streams.concurrent.cursor import CursorField
 from airbyte_cdk.sources.streams.concurrent.state_converters.datetime_stream_state_converter import (
@@ -3624,6 +3625,7 @@ def test_given_no_partitions_processed_when_close_partition_then_no_state_update
             message_repository=MagicMock(),
             max_records_limit=None,
             stream_slice=slice,
+            record_counter=RecordCounter(),
         )
     )
 
@@ -3709,6 +3711,7 @@ def test_given_unfinished_first_parent_partition_no_parent_state_update():
             message_repository=MagicMock(),
             max_records_limit=None,
             stream_slice=slice,
+            record_counter=RecordCounter(),
         )
     )
     cursor.ensure_at_least_one_state_emitted()
@@ -3804,6 +3807,7 @@ def test_given_unfinished_last_parent_partition_with_partial_parent_state_update
             message_repository=MagicMock(),
             max_records_limit=None,
             stream_slice=slice,
+            record_counter=RecordCounter(),
         )
     )
     cursor.ensure_at_least_one_state_emitted()
@@ -3894,6 +3898,7 @@ def test_given_all_partitions_finished_when_close_partition_then_final_state_emi
             message_repository=MagicMock(),
             max_records_limit=None,
             stream_slice=slice,
+            record_counter=RecordCounter(),
         )
     )
 
@@ -3968,6 +3973,7 @@ def test_given_partition_limit_exceeded_when_close_partition_then_switch_to_glob
             message_repository=MagicMock(),
             max_records_limit=None,
             stream_slice=slice,
+            record_counter=RecordCounter(),
         )
     )
     cursor.ensure_at_least_one_state_emitted()
@@ -4053,6 +4059,7 @@ def test_semaphore_cleanup():
             message_repository=MagicMock(),
             max_records_limit=None,
             stream_slice=s,
+            record_counter=RecordCounter(),
         )
     )
 
@@ -4173,6 +4180,7 @@ def test_duplicate_partition_after_closing_partition_cursor_deleted():
             message_repository=MagicMock(),
             max_records_limit=None,
             stream_slice=first_1,
+            record_counter=RecordCounter(),
         )
     )
 
@@ -4185,6 +4193,7 @@ def test_duplicate_partition_after_closing_partition_cursor_deleted():
             message_repository=MagicMock(),
             max_records_limit=None,
             stream_slice=two,
+            record_counter=RecordCounter(),
         )
     )
 
@@ -4197,6 +4206,7 @@ def test_duplicate_partition_after_closing_partition_cursor_deleted():
             message_repository=MagicMock(),
             max_records_limit=None,
             stream_slice=second_1,
+            record_counter=RecordCounter(),
         )
     )
 
@@ -4258,6 +4268,7 @@ def test_duplicate_partition_after_closing_partition_cursor_exists():
             message_repository=MagicMock(),
             max_records_limit=None,
             stream_slice=first_1,
+            record_counter=RecordCounter(),
         )
     )
 
@@ -4270,6 +4281,7 @@ def test_duplicate_partition_after_closing_partition_cursor_exists():
             message_repository=MagicMock(),
             max_records_limit=None,
             stream_slice=two,
+            record_counter=RecordCounter(),
         )
     )
 
@@ -4283,6 +4295,7 @@ def test_duplicate_partition_after_closing_partition_cursor_exists():
             message_repository=MagicMock(),
             max_records_limit=None,
             stream_slice=second_1,
+            record_counter=RecordCounter(),
         )
     )
 
@@ -4341,6 +4354,7 @@ def test_duplicate_partition_while_processing():
             message_repository=MagicMock(),
             max_records_limit=None,
             stream_slice=generated[1],
+            record_counter=RecordCounter(),
         )
     )
     # Now close the initial “1”
@@ -4352,6 +4366,7 @@ def test_duplicate_partition_while_processing():
             message_repository=MagicMock(),
             max_records_limit=None,
             stream_slice=generated[0],
+            record_counter=RecordCounter(),
         )
     )
 
diff --git a/unit_tests/sources/declarative/stream_slicers/test_declarative_partition_generator.py b/unit_tests/sources/declarative/stream_slicers/test_declarative_partition_generator.py
index f9e2779f1..8ed712d1a 100644
--- a/unit_tests/sources/declarative/stream_slicers/test_declarative_partition_generator.py
+++ b/unit_tests/sources/declarative/stream_slicers/test_declarative_partition_generator.py
@@ -4,8 +4,6 @@
 from unittest import TestCase
 from unittest.mock import Mock
 
-# This allows for the global total_record_counter to be reset between tests
-import airbyte_cdk.sources.declarative.stream_slicers.declarative_partition_generator as declarative_partition_generator
 from airbyte_cdk.models import AirbyteLogMessage, AirbyteMessage, Level, Type
 from airbyte_cdk.sources.declarative.retrievers import Retriever
 from airbyte_cdk.sources.declarative.schema import InlineSchemaLoader
@@ -35,7 +33,7 @@ class StreamSlicerPartitionGeneratorTest(TestCase):
     def test_given_multiple_slices_partition_generator_uses_the_same_retriever(self) -> None:
         retriever = self._mock_retriever([])
         message_repository = Mock(spec=MessageRepository)
-        partition_factory = declarative_partition_generator.DeclarativePartitionFactory(
+        partition_factory = DeclarativePartitionFactory(
             _STREAM_NAME,
             _SCHEMA_LOADER,
             retriever,
@@ -50,7 +48,7 @@ def test_given_multiple_slices_partition_generator_uses_the_same_retriever(self)
     def test_given_a_mapping_when_read_then_yield_record(self) -> None:
         retriever = self._mock_retriever([_A_RECORD])
         message_repository = Mock(spec=MessageRepository)
-        partition_factory = declarative_partition_generator.DeclarativePartitionFactory(
+        partition_factory = DeclarativePartitionFactory(
             _STREAM_NAME,
             _SCHEMA_LOADER,
             retriever,
@@ -68,7 +66,7 @@ def test_given_a_mapping_when_read_then_yield_record(self) -> None:
     def test_given_not_a_record_when_read_then_send_to_message_repository(self) -> None:
         retriever = self._mock_retriever([_AIRBYTE_LOG_MESSAGE])
         message_repository = Mock(spec=MessageRepository)
-        partition_factory = declarative_partition_generator.DeclarativePartitionFactory(
+        partition_factory = DeclarativePartitionFactory(
             _STREAM_NAME,
             _SCHEMA_LOADER,
             retriever,
@@ -80,8 +78,6 @@ def test_given_not_a_record_when_read_then_send_to_message_repository(self) -> N
         message_repository.emit_message.assert_called_once_with(_AIRBYTE_LOG_MESSAGE)
 
     def test_max_records_reached_stops_reading(self) -> None:
-        declarative_partition_generator.total_record_counter = 0
-
         expected_records = [
             Record(data={"id": 1, "name": "Max"}, stream_name="stream_name"),
             Record(data={"id": 1, "name": "Oscar"}, stream_name="stream_name"),
@@ -97,7 +93,7 @@ def test_max_records_reached_stops_reading(self) -> None:
         retriever = self._mock_retriever(mock_records)
         message_repository = Mock(spec=MessageRepository)
 
-        partition_factory = declarative_partition_generator.DeclarativePartitionFactory(
+        partition_factory = DeclarativePartitionFactory(
             _STREAM_NAME,
             _SCHEMA_LOADER,
             retriever,
@@ -113,8 +109,6 @@ def test_max_records_reached_stops_reading(self) -> None:
         assert actual_records == expected_records
 
     def test_max_records_reached_on_previous_partition(self) -> None:
-        declarative_partition_generator.total_record_counter = 0
-
         expected_records = [
             Record(data={"id": 1, "name": "Max"}, stream_name="stream_name"),
             Record(data={"id": 1, "name": "Oscar"}, stream_name="stream_name"),
@@ -128,7 +122,7 @@ def test_max_records_reached_on_previous_partition(self) -> None:
         retriever = self._mock_retriever(mock_records)
         message_repository = Mock(spec=MessageRepository)
 
-        partition_factory = declarative_partition_generator.DeclarativePartitionFactory(
+        partition_factory = DeclarativePartitionFactory(
             _STREAM_NAME,
             _SCHEMA_LOADER,
             retriever,
@@ -151,6 +145,55 @@ def test_max_records_reached_on_previous_partition(self) -> None:
         # called for the first partition read and not the second
         retriever.read_records.assert_called_once()
 
+    def test_record_counter_isolation_between_different_factories(self) -> None:
+        """Test that record counters are isolated between different DeclarativePartitionFactory instances."""
+
+        # Create mock records that exceed the limit
+        records = [
+            Record(data={"id": 1, "name": "Record1"}, stream_name="stream_name"),
+            Record(data={"id": 2, "name": "Record2"}, stream_name="stream_name"),
+            Record(
+                data={"id": 3, "name": "Record3"}, stream_name="stream_name"
+            ),  # Should be blocked by limit
+        ]
+
+        # Create first factory with record limit of 2
+        retriever1 = self._mock_retriever(records)
+        message_repository1 = Mock(spec=MessageRepository)
+        factory1 = DeclarativePartitionFactory(
+            _STREAM_NAME,
+            _SCHEMA_LOADER,
+            retriever1,
+            message_repository1,
+            max_records_limit=2,
+        )
+
+        # First factory should read up to limit (2 records)
+        partition1 = factory1.create(_A_STREAM_SLICE)
+        first_factory_records = list(partition1.read())
+        assert len(first_factory_records) == 2
+
+        # Create second factory with same limit - should be independent
+        retriever2 = self._mock_retriever(records)
+        message_repository2 = Mock(spec=MessageRepository)
+        factory2 = DeclarativePartitionFactory(
+            _STREAM_NAME,
+            _SCHEMA_LOADER,
+            retriever2,
+            message_repository2,
+            max_records_limit=2,
+        )
+
+        # Second factory should also be able to read up to limit (2 records)
+        # This would fail before the fix because record counter was global
+        partition2 = factory2.create(_A_STREAM_SLICE)
+        second_factory_records = list(partition2.read())
+        assert len(second_factory_records) == 2
+
+        # Verify both retrievers were called (confirming isolation)
+        retriever1.read_records.assert_called_once()
+        retriever2.read_records.assert_called_once()
+
     @staticmethod
     def _mock_retriever(read_return_value: List[StreamData]) -> Mock:
         retriever = Mock(spec=Retriever)
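Note (not part of the patch): the sketch below distills the scoping change this diff locks in with `test_record_counter_isolation_between_different_factories`. The inlined `RecordCounter` mirrors the class added to `declarative_partition_generator.py`; `read_with_limit` is a hypothetical stand-in for the limit check inside `DeclarativePartition.read()`, not CDK API.

```python
from typing import List


# Mirrors the RecordCounter added in declarative_partition_generator.py.
class RecordCounter:
    def __init__(self) -> None:
        self.total_record_counter = 0

    def increment(self) -> None:
        self.total_record_counter += 1

    def get_total_records(self) -> int:
        return self.total_record_counter


# Hypothetical stand-in for the limit check in DeclarativePartition.read().
def read_with_limit(counter: RecordCounter, records: List[str], limit: int) -> List[str]:
    emitted = []
    for record in records:
        # Stop early once the stream-level record limit is reached.
        if counter.get_total_records() >= limit:
            break
        counter.increment()
        emitted.append(record)
    return emitted


records = ["r1", "r2", "r3"]

# Per-factory counters (new behavior): each test read honors its own limit.
assert len(read_with_limit(RecordCounter(), records, limit=2)) == 2
assert len(read_with_limit(RecordCounter(), records, limit=2)) == 2

# A single shared counter (the old module-level global): the first read
# saturates it, so a later read yields nothing.
shared = RecordCounter()
assert len(read_with_limit(shared, records, limit=2)) == 2
assert len(read_with_limit(shared, records, limit=2)) == 0
```

Because the counter lives on the factory rather than on each partition, the limit is still enforced across all slices of a single test read (see `test_max_records_reached_on_previous_partition`), while separate factories no longer leak state into one another.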