1 change: 0 additions & 1 deletion airbyte_cdk/manifest_server/command_processor/processor.py
@@ -41,7 +41,6 @@ def test_read(
"""
Test the read method of the source.
"""
-
test_read_handler = TestReader(
max_pages_per_slice=page_limit,
max_slices=slice_limit,
airbyte_cdk/sources/declarative/stream_slicers/declarative_partition_generator.py
@@ -14,10 +14,21 @@
from airbyte_cdk.sources.types import Record, StreamSlice
from airbyte_cdk.utils.slice_hasher import SliceHasher


# For Connector Builder test read operations, we track the total number of records
-# read for the stream at the global level so that we can stop reading early if we
-# exceed the record limit
-total_record_counter = 0
+# read for the stream so that we can stop reading early if we exceed the record limit.
+class RecordCounter:
+    def __init__(self) -> None:
+        self.total_record_counter = 0
+
+    def increment(self) -> None:
+        self.total_record_counter += 1
+
+    def reset(self) -> None:
+        self.total_record_counter = 0
+
+    def get_total_records(self) -> int:
+        return self.total_record_counter


class SchemaLoaderCachingDecorator(SchemaLoader):
@@ -51,6 +62,7 @@ def __init__(
self._retriever = retriever
self._message_repository = message_repository
self._max_records_limit = max_records_limit
+        self._record_counter = RecordCounter()

def create(self, stream_slice: StreamSlice) -> Partition:
return DeclarativePartition(
@@ -60,6 +72,7 @@ def create(self, stream_slice: StreamSlice) -> Partition:
message_repository=self._message_repository,
max_records_limit=self._max_records_limit,
stream_slice=stream_slice,
+            record_counter=self._record_counter,
)


@@ -72,6 +85,7 @@ def __init__(
message_repository: MessageRepository,
max_records_limit: Optional[int],
stream_slice: StreamSlice,
+        record_counter: RecordCounter,
):
self._stream_name = stream_name
self._schema_loader = schema_loader
@@ -80,17 +94,17 @@ def __init__(
self._max_records_limit = max_records_limit
self._stream_slice = stream_slice
self._hash = SliceHasher.hash(self._stream_name, self._stream_slice)
+        self._record_counter = record_counter

def read(self) -> Iterable[Record]:
if self._max_records_limit is not None:
-            global total_record_counter
-            if total_record_counter >= self._max_records_limit:
+            if self._record_counter.get_total_records() >= self._max_records_limit:
return
for stream_data in self._retriever.read_records(
self._schema_loader.get_json_schema(), self._stream_slice
):
if self._max_records_limit is not None:
-                if total_record_counter >= self._max_records_limit:
+                if self._record_counter.get_total_records() >= self._max_records_limit:
break

if isinstance(stream_data, Mapping):
@@ -108,7 +122,7 @@ def read(self) -> Iterable[Record]:
self._message_repository.emit_message(stream_data)

if self._max_records_limit is not None:
-                total_record_counter += 1
+                self._record_counter.increment()

def to_slice(self) -> Optional[Mapping[str, Any]]:
return self._stream_slice
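
The hunks above replace the module-level total_record_counter with a RecordCounter instance that is owned by each DeclarativePartitionFactory and handed to every partition it creates. A minimal standalone sketch of why the scope matters (ToyFactory and its fixture values are hypothetical, not CDK code; RecordCounter is the class defined above): a module-level counter survives across factories, so a second test read would start at the previous read's count, while an instance-scoped counter starts fresh with each factory.

# Toy model of the fix: one counter per factory.
class ToyFactory:
    def __init__(self, limit: int) -> None:
        self._limit = limit
        self._counter = RecordCounter()  # instance-scoped, not module-global

    def read(self, records: list) -> list:
        out = []
        for record in records:
            if self._counter.get_total_records() >= self._limit:
                break  # stop early once this factory's limit is reached
            out.append(record)
            self._counter.increment()
        return out

records = ["r1", "r2", "r3"]
assert ToyFactory(limit=2).read(records) == ["r1", "r2"]
# With a module-global counter the next call would yield nothing;
# with per-factory counters it reads up to its own limit again.
assert ToyFactory(limit=2).read(records) == ["r1", "r2"]
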
@@ -23,6 +23,7 @@
from airbyte_cdk.sources.declarative.schema import InlineSchemaLoader
from airbyte_cdk.sources.declarative.stream_slicers.declarative_partition_generator import (
DeclarativePartition,
+    RecordCounter,
)
from airbyte_cdk.sources.streams.concurrent.cursor import CursorField
from airbyte_cdk.sources.streams.concurrent.state_converters.datetime_stream_state_converter import (
@@ -3624,6 +3625,7 @@ def test_given_no_partitions_processed_when_close_partition_then_no_state_update
message_repository=MagicMock(),
max_records_limit=None,
stream_slice=slice,
+            record_counter=RecordCounter(),
)
)

@@ -3709,6 +3711,7 @@ def test_given_unfinished_first_parent_partition_no_parent_state_update():
message_repository=MagicMock(),
max_records_limit=None,
stream_slice=slice,
+            record_counter=RecordCounter(),
)
)
cursor.ensure_at_least_one_state_emitted()
@@ -3804,6 +3807,7 @@ def test_given_unfinished_last_parent_partition_with_partial_parent_state_update
message_repository=MagicMock(),
max_records_limit=None,
stream_slice=slice,
+            record_counter=RecordCounter(),
)
)
cursor.ensure_at_least_one_state_emitted()
@@ -3894,6 +3898,7 @@ def test_given_all_partitions_finished_when_close_partition_then_final_state_emi
message_repository=MagicMock(),
max_records_limit=None,
stream_slice=slice,
+            record_counter=RecordCounter(),
)
)

@@ -3968,6 +3973,7 @@ def test_given_partition_limit_exceeded_when_close_partition_then_switch_to_glob
message_repository=MagicMock(),
max_records_limit=None,
stream_slice=slice,
+            record_counter=RecordCounter(),
)
)
cursor.ensure_at_least_one_state_emitted()
@@ -4053,6 +4059,7 @@ def test_semaphore_cleanup():
message_repository=MagicMock(),
max_records_limit=None,
stream_slice=s,
+            record_counter=RecordCounter(),
)
)

@@ -4173,6 +4180,7 @@ def test_duplicate_partition_after_closing_partition_cursor_deleted():
message_repository=MagicMock(),
max_records_limit=None,
stream_slice=first_1,
+            record_counter=RecordCounter(),
)
)

@@ -4185,6 +4193,7 @@ def test_duplicate_partition_after_closing_partition_cursor_deleted():
message_repository=MagicMock(),
max_records_limit=None,
stream_slice=two,
+            record_counter=RecordCounter(),
)
)

@@ -4197,6 +4206,7 @@ def test_duplicate_partition_after_closing_partition_cursor_deleted():
message_repository=MagicMock(),
max_records_limit=None,
stream_slice=second_1,
+            record_counter=RecordCounter(),
)
)

@@ -4258,6 +4268,7 @@ def test_duplicate_partition_after_closing_partition_cursor_exists():
message_repository=MagicMock(),
max_records_limit=None,
stream_slice=first_1,
+            record_counter=RecordCounter(),
)
)

@@ -4270,6 +4281,7 @@ def test_duplicate_partition_after_closing_partition_cursor_exists():
message_repository=MagicMock(),
max_records_limit=None,
stream_slice=two,
+            record_counter=RecordCounter(),
)
)

@@ -4283,6 +4295,7 @@ def test_duplicate_partition_after_closing_partition_cursor_exists():
message_repository=MagicMock(),
max_records_limit=None,
stream_slice=second_1,
+            record_counter=RecordCounter(),
)
)

@@ -4341,6 +4354,7 @@ def test_duplicate_partition_while_processing():
message_repository=MagicMock(),
max_records_limit=None,
stream_slice=generated[1],
+            record_counter=RecordCounter(),
)
)
# Now close the initial “1”
@@ -4352,6 +4366,7 @@ def test_duplicate_partition_while_processing():
message_repository=MagicMock(),
max_records_limit=None,
stream_slice=generated[0],
+            record_counter=RecordCounter(),
)
)

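
The hunks above thread the new constructor argument through the concurrent per-partition cursor test suite: DeclarativePartition now requires a record_counter, so every construction passes a fresh RecordCounter(). Condensed to the shape used throughout these tests (cursor setup elided; the positional arguments follow the constructor shown earlier, and the keyword values mirror the tests):

partition = DeclarativePartition(
    "stream_name",
    schema_loader,
    retriever,
    message_repository=MagicMock(),
    max_records_limit=None,
    stream_slice=stream_slice,
    record_counter=RecordCounter(),  # newly required
)
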
@@ -4,8 +4,6 @@
from unittest import TestCase
from unittest.mock import Mock

-# This allows for the global total_record_counter to be reset between tests
-import airbyte_cdk.sources.declarative.stream_slicers.declarative_partition_generator as declarative_partition_generator
from airbyte_cdk.models import AirbyteLogMessage, AirbyteMessage, Level, Type
from airbyte_cdk.sources.declarative.retrievers import Retriever
from airbyte_cdk.sources.declarative.schema import InlineSchemaLoader
@@ -35,7 +33,7 @@ class StreamSlicerPartitionGeneratorTest(TestCase):
def test_given_multiple_slices_partition_generator_uses_the_same_retriever(self) -> None:
retriever = self._mock_retriever([])
message_repository = Mock(spec=MessageRepository)
-        partition_factory = declarative_partition_generator.DeclarativePartitionFactory(
+        partition_factory = DeclarativePartitionFactory(
_STREAM_NAME,
_SCHEMA_LOADER,
retriever,
@@ -50,7 +48,7 @@ def test_given_multiple_slices_partition_generator_uses_the_same_retriever(self)
def test_given_a_mapping_when_read_then_yield_record(self) -> None:
retriever = self._mock_retriever([_A_RECORD])
message_repository = Mock(spec=MessageRepository)
-        partition_factory = declarative_partition_generator.DeclarativePartitionFactory(
+        partition_factory = DeclarativePartitionFactory(
_STREAM_NAME,
_SCHEMA_LOADER,
retriever,
@@ -68,7 +66,7 @@ def test_given_a_mapping_when_read_then_yield_record(self) -> None:
def test_given_not_a_record_when_read_then_send_to_message_repository(self) -> None:
retriever = self._mock_retriever([_AIRBYTE_LOG_MESSAGE])
message_repository = Mock(spec=MessageRepository)
-        partition_factory = declarative_partition_generator.DeclarativePartitionFactory(
+        partition_factory = DeclarativePartitionFactory(
_STREAM_NAME,
_SCHEMA_LOADER,
retriever,
@@ -80,8 +78,6 @@ def test_given_not_a_record_when_read_then_send_to_message_repository(self) -> N
message_repository.emit_message.assert_called_once_with(_AIRBYTE_LOG_MESSAGE)

def test_max_records_reached_stops_reading(self) -> None:
-        declarative_partition_generator.total_record_counter = 0
-
expected_records = [
Record(data={"id": 1, "name": "Max"}, stream_name="stream_name"),
Record(data={"id": 1, "name": "Oscar"}, stream_name="stream_name"),
@@ -97,7 +93,7 @@ def test_max_records_reached_stops_reading(self) -> None:

retriever = self._mock_retriever(mock_records)
message_repository = Mock(spec=MessageRepository)
-        partition_factory = declarative_partition_generator.DeclarativePartitionFactory(
+        partition_factory = DeclarativePartitionFactory(
_STREAM_NAME,
_SCHEMA_LOADER,
retriever,
@@ -113,8 +109,6 @@ def test_max_records_reached_stops_reading(self) -> None:
assert actual_records == expected_records

def test_max_records_reached_on_previous_partition(self) -> None:
-        declarative_partition_generator.total_record_counter = 0
-
expected_records = [
Record(data={"id": 1, "name": "Max"}, stream_name="stream_name"),
Record(data={"id": 1, "name": "Oscar"}, stream_name="stream_name"),
@@ -128,7 +122,7 @@ def test_max_records_reached_on_previous_partition(self) -> None:

retriever = self._mock_retriever(mock_records)
message_repository = Mock(spec=MessageRepository)
-        partition_factory = declarative_partition_generator.DeclarativePartitionFactory(
+        partition_factory = DeclarativePartitionFactory(
_STREAM_NAME,
_SCHEMA_LOADER,
retriever,
@@ -151,6 +145,55 @@ def test_max_records_reached_on_previous_partition(self) -> None:
# called for the first partition read and not the second
retriever.read_records.assert_called_once()

+    def test_record_counter_isolation_between_different_factories(self) -> None:
+        """Test that record counters are isolated between different DeclarativePartitionFactory instances."""
+
+        # Create mock records that exceed the limit
+        records = [
+            Record(data={"id": 1, "name": "Record1"}, stream_name="stream_name"),
+            Record(data={"id": 2, "name": "Record2"}, stream_name="stream_name"),
+            Record(
+                data={"id": 3, "name": "Record3"}, stream_name="stream_name"
+            ),  # Should be blocked by the limit
+        ]
+
+        # Create the first factory with a record limit of 2
+        retriever1 = self._mock_retriever(records)
+        message_repository1 = Mock(spec=MessageRepository)
+        factory1 = DeclarativePartitionFactory(
+            _STREAM_NAME,
+            _SCHEMA_LOADER,
+            retriever1,
+            message_repository1,
+            max_records_limit=2,
+        )
+
+        # The first factory should read up to the limit (2 records)
+        partition1 = factory1.create(_A_STREAM_SLICE)
+        first_factory_records = list(partition1.read())
+        assert len(first_factory_records) == 2
+
+        # Create a second factory with the same limit - it should be independent
+        retriever2 = self._mock_retriever(records)
+        message_repository2 = Mock(spec=MessageRepository)
+        factory2 = DeclarativePartitionFactory(
+            _STREAM_NAME,
+            _SCHEMA_LOADER,
+            retriever2,
+            message_repository2,
+            max_records_limit=2,
+        )
+
+        # The second factory should also be able to read up to the limit (2 records).
+        # This would fail before the fix because the record counter was global.
+        partition2 = factory2.create(_A_STREAM_SLICE)
+        second_factory_records = list(partition2.read())
+        assert len(second_factory_records) == 2
+
+        # Verify both retrievers were called (confirming isolation)
+        retriever1.read_records.assert_called_once()
+        retriever2.read_records.assert_called_once()

@staticmethod
def _mock_retriever(read_return_value: List[StreamData]) -> Mock:
retriever = Mock(spec=Retriever)
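
Together, the updated tests pin down both scopes of the counter: partitions created by one factory share a single RecordCounter, so max_records_limit applies across all slices of a test read (test_max_records_reached_on_previous_partition), while separate factories count independently (the new isolation test). A condensed sketch of the shared-counter half, written as it would appear inside this TestCase (three_records and _ANOTHER_STREAM_SLICE are hypothetical stand-ins; everything else follows the tests above):

retriever = self._mock_retriever(three_records)  # more records than the limit
factory = DeclarativePartitionFactory(
    _STREAM_NAME,
    _SCHEMA_LOADER,
    retriever,
    Mock(spec=MessageRepository),
    max_records_limit=2,
)
first = list(factory.create(_A_STREAM_SLICE).read())
second = list(factory.create(_ANOTHER_STREAM_SLICE).read())
assert len(first) == 2  # stops mid-partition at the limit
assert second == []  # the shared counter is already at the limit
retriever.read_records.assert_called_once()  # the second partition never hits the retriever
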