Skip to content

Commit

Permalink
Add unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
tolik0 committed May 23, 2024
1 parent 24f9990 commit 84b40a0
Show file tree
Hide file tree
Showing 4 changed files with 222 additions and 46 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
from typing import Any, Iterable, List, Mapping, Optional, Union

from airbyte_cdk.sources.declarative.interpolation.interpolated_string import InterpolatedString
from airbyte_cdk.sources.declarative.requesters.request_option import RequestOption, RequestOptionType
from airbyte_cdk.sources.declarative.partition_routers.partition_router import PartitionRouter
from airbyte_cdk.sources.declarative.requesters.request_option import RequestOption, RequestOptionType
from airbyte_cdk.sources.types import Config, StreamSlice, StreamState


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -134,9 +134,10 @@ def stream_slices(self) -> Iterable[StreamSlice]:
for parent_stream_slice in parent_stream.stream_slices(
sync_mode=SyncMode.full_refresh, cursor_field=None, stream_state=None
):
empty_parent_slice = True
parent_partition = parent_stream_slice.partition if parent_stream_slice else {}

# we need to read all records for the slice to update the parent stream cursor
stream_slices_for_parent = []
for parent_record in parent_stream.read_records(
sync_mode=SyncMode.full_refresh, cursor_field=None, stream_slice=parent_stream_slice, stream_state=None
):
Expand All @@ -153,16 +154,15 @@ def stream_slices(self) -> Iterable[StreamSlice]:
except KeyError:
pass
else:
empty_parent_slice = False
yield StreamSlice(
partition={partition_field: partition_value, "parent_slice": parent_partition}, cursor_slice={}
stream_slices_for_parent.append(
StreamSlice(partition={partition_field: partition_value, "parent_slice": parent_partition}, cursor_slice={})
)

# update the parent state: the parent stream has read all records for the current slice, so its state is already up to date
if incremental_dependency:
self._parent_state[parent_stream.name] = parent_stream.state

# If the parent slice contains no records,
if empty_parent_slice:
yield from []
yield from stream_slices_for_parent

def set_parent_state(self, stream_state: Optional[StreamState]) -> None:
"""
Expand Down Expand Up @@ -190,4 +190,5 @@ def get_parent_state(self) -> StreamState:
Returns:
StreamState: The current state of the parent streams.
"""
parent_stream_name = self.parent_stream_configs[0].stream.name if self.parent_stream_configs else None
return self._parent_state
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import pytest
from airbyte_cdk.sources.declarative.incremental.declarative_cursor import DeclarativeCursor
from airbyte_cdk.sources.declarative.incremental.per_partition_cursor import PerPartitionCursor, PerPartitionKeySerializer, StreamSlice
from airbyte_cdk.sources.declarative.stream_slicers.stream_slicer import StreamSlicer
from airbyte_cdk.sources.declarative.partition_routers.partition_router import PartitionRouter
from airbyte_cdk.sources.types import Record

PARTITION = {
Expand Down Expand Up @@ -105,7 +105,7 @@ def build(self):

@pytest.fixture()
def mocked_partition_router():
return Mock(spec=StreamSlicer)
return Mock(spec=PartitionRouter)


@pytest.fixture()
Expand Down Expand Up @@ -157,7 +157,10 @@ def test_given_partition_associated_with_state_when_stream_slices_then_do_not_re

def test_given_multiple_partitions_then_each_have_their_state(mocked_cursor_factory, mocked_partition_router):
first_partition = {"first_partition_key": "first_partition_value"}
mocked_partition_router.stream_slices.return_value = [StreamSlice(partition=first_partition, cursor_slice={}), StreamSlice(partition={"second_partition_key": "second_partition_value"}, cursor_slice={})]
mocked_partition_router.stream_slices.return_value = [
StreamSlice(partition=first_partition, cursor_slice={}),
StreamSlice(partition={"second_partition_key": "second_partition_value"}, cursor_slice={}),
]
first_cursor = MockedCursorBuilder().with_stream_slices([{CURSOR_SLICE_FIELD: "first slice cursor value"}]).build()
second_cursor = MockedCursorBuilder().with_stream_slices([{CURSOR_SLICE_FIELD: "second slice cursor value"}]).build()
mocked_cursor_factory.create.side_effect = [first_cursor, second_cursor]
Expand All @@ -183,7 +186,14 @@ def test_given_stream_slices_when_get_stream_state_then_return_updated_state(moc
MockedCursorBuilder().with_stream_state({CURSOR_STATE_KEY: "first slice cursor value"}).build(),
MockedCursorBuilder().with_stream_state({CURSOR_STATE_KEY: "second slice cursor value"}).build(),
]
mocked_partition_router.stream_slices.return_value = [StreamSlice(partition={"partition key": "first partition"}, cursor_slice={}), StreamSlice(partition={"partition key": "second partition"}, cursor_slice={})]
mocked_partition_router.stream_slices.return_value = [
StreamSlice(partition={"partition key": "first partition"}, cursor_slice={}),
StreamSlice(partition={"partition key": "second partition"}, cursor_slice={}),
]

# Mock the get_parent_state method to return the parent state
mocked_partition_router.get_parent_state.return_value = {}

cursor = PerPartitionCursor(mocked_cursor_factory, mocked_partition_router)
list(cursor.stream_slices())
assert cursor.get_stream_state() == {
Expand Down Expand Up @@ -265,17 +275,14 @@ def test_given_records_with_different_slice_when_is_greater_than_or_equal_then_r
[
pytest.param(StreamSlice(partition={"a slice": "value"}, cursor_slice={}), None, id="second record does not have a slice"),
pytest.param(None, StreamSlice(partition={"a slice": "value"}, cursor_slice={}), id="first record does not have a slice"),
]
],
)
def test_given_records_without_a_slice_when_is_greater_than_or_equal_then_raise_error(first_record_slice, second_record_slice):
any_cursor_factory = Mock()
any_partition_router = Mock()
cursor = PerPartitionCursor(any_cursor_factory, any_partition_router)
with pytest.raises(ValueError):
cursor.is_greater_than_or_equal(
Record({}, first_record_slice),
Record({}, second_record_slice)
)
cursor.is_greater_than_or_equal(Record({}, first_record_slice), Record({}, second_record_slice))


def test_given_slice_is_unknown_when_is_greater_than_or_equal_then_raise_error():
Expand Down Expand Up @@ -308,9 +315,13 @@ def test_when_is_greater_than_or_equal_then_return_underlying_cursor_response(mo
@pytest.mark.parametrize(
"stream_slice, expected_output",
[
pytest.param(StreamSlice(partition={"partition key": "first partition"}, cursor_slice={}), {"cursor": "params", "router": "params"}, id="first partition"),
pytest.param(
StreamSlice(partition={"partition key": "first partition"}, cursor_slice={}),
{"cursor": "params", "router": "params"},
id="first partition",
),
pytest.param(None, None, id="first partition"),
]
],
)
def test_get_request_params(mocked_cursor_factory, mocked_partition_router, stream_slice, expected_output):
underlying_cursor = MockedCursorBuilder().with_stream_slices([{CURSOR_SLICE_FIELD: "first slice cursor value"}]).build()
Expand All @@ -323,7 +334,9 @@ def test_get_request_params(mocked_cursor_factory, mocked_partition_router, stre
cursor.set_initial_state({"states": [{"partition": stream_slice.partition, "cursor": CURSOR_STATE}]})
params = cursor.get_request_params(stream_slice=stream_slice)
assert params == expected_output
mocked_partition_router.get_request_params.assert_called_once_with(stream_state=None, stream_slice=stream_slice, next_page_token=None)
mocked_partition_router.get_request_params.assert_called_once_with(
stream_state=None, stream_slice=stream_slice, next_page_token=None
)
underlying_cursor.get_request_params.assert_called_once_with(stream_state=None, stream_slice={}, next_page_token=None)
else:
with pytest.raises(ValueError):
Expand All @@ -333,9 +346,13 @@ def test_get_request_params(mocked_cursor_factory, mocked_partition_router, stre
@pytest.mark.parametrize(
"stream_slice, expected_output",
[
pytest.param(StreamSlice(partition={"partition key": "first partition"}, cursor_slice={}), {"cursor": "params", "router": "params"}, id="first partition"),
pytest.param(
StreamSlice(partition={"partition key": "first partition"}, cursor_slice={}),
{"cursor": "params", "router": "params"},
id="first partition",
),
pytest.param(None, None, id="first partition"),
]
],
)
def test_get_request_headers(mocked_cursor_factory, mocked_partition_router, stream_slice, expected_output):
underlying_cursor = MockedCursorBuilder().with_stream_slices([{CURSOR_SLICE_FIELD: "first slice cursor value"}]).build()
Expand All @@ -348,7 +365,9 @@ def test_get_request_headers(mocked_cursor_factory, mocked_partition_router, str
cursor.set_initial_state({"states": [{"partition": stream_slice.partition, "cursor": CURSOR_STATE}]})
params = cursor.get_request_headers(stream_slice=stream_slice)
assert params == expected_output
mocked_partition_router.get_request_headers.assert_called_once_with(stream_state=None, stream_slice=stream_slice, next_page_token=None)
mocked_partition_router.get_request_headers.assert_called_once_with(
stream_state=None, stream_slice=stream_slice, next_page_token=None
)
underlying_cursor.get_request_headers.assert_called_once_with(stream_state=None, stream_slice={}, next_page_token=None)
else:
with pytest.raises(ValueError):
Expand All @@ -358,9 +377,13 @@ def test_get_request_headers(mocked_cursor_factory, mocked_partition_router, str
@pytest.mark.parametrize(
"stream_slice, expected_output",
[
pytest.param(StreamSlice(partition={"partition key": "first partition"}, cursor_slice={}), {"cursor": "params", "router": "params"}, id="first partition"),
pytest.param(
StreamSlice(partition={"partition key": "first partition"}, cursor_slice={}),
{"cursor": "params", "router": "params"},
id="first partition",
),
pytest.param(None, None, id="first partition"),
]
],
)
def test_get_request_body_data(mocked_cursor_factory, mocked_partition_router, stream_slice, expected_output):
underlying_cursor = MockedCursorBuilder().with_stream_slices([{CURSOR_SLICE_FIELD: "first slice cursor value"}]).build()
Expand All @@ -373,7 +396,9 @@ def test_get_request_body_data(mocked_cursor_factory, mocked_partition_router, s
cursor.set_initial_state({"states": [{"partition": stream_slice.partition, "cursor": CURSOR_STATE}]})
params = cursor.get_request_body_data(stream_slice=stream_slice)
assert params == expected_output
mocked_partition_router.get_request_body_data.assert_called_once_with(stream_state=None, stream_slice=stream_slice, next_page_token=None)
mocked_partition_router.get_request_body_data.assert_called_once_with(
stream_state=None, stream_slice=stream_slice, next_page_token=None
)
underlying_cursor.get_request_body_data.assert_called_once_with(stream_state=None, stream_slice={}, next_page_token=None)
else:
with pytest.raises(ValueError):
Expand All @@ -383,9 +408,13 @@ def test_get_request_body_data(mocked_cursor_factory, mocked_partition_router, s
@pytest.mark.parametrize(
"stream_slice, expected_output",
[
pytest.param(StreamSlice(partition={"partition key": "first partition"}, cursor_slice={}), {"cursor": "params", "router": "params"}, id="first partition"),
pytest.param(
StreamSlice(partition={"partition key": "first partition"}, cursor_slice={}),
{"cursor": "params", "router": "params"},
id="first partition",
),
pytest.param(None, None, id="first partition"),
]
],
)
def test_get_request_body_json(mocked_cursor_factory, mocked_partition_router, stream_slice, expected_output):
underlying_cursor = MockedCursorBuilder().with_stream_slices([{CURSOR_SLICE_FIELD: "first slice cursor value"}]).build()
Expand All @@ -398,8 +427,91 @@ def test_get_request_body_json(mocked_cursor_factory, mocked_partition_router, s
cursor.set_initial_state({"states": [{"partition": stream_slice.partition, "cursor": CURSOR_STATE}]})
params = cursor.get_request_body_json(stream_slice=stream_slice)
assert params == expected_output
mocked_partition_router.get_request_body_json.assert_called_once_with(stream_state=None, stream_slice=stream_slice, next_page_token=None)
mocked_partition_router.get_request_body_json.assert_called_once_with(
stream_state=None, stream_slice=stream_slice, next_page_token=None
)
underlying_cursor.get_request_body_json.assert_called_once_with(stream_state=None, stream_slice={}, next_page_token=None)
else:
with pytest.raises(ValueError):
cursor.get_request_body_json(stream_slice=stream_slice)


def test_parent_state_is_set_for_per_partition_cursor(mocked_cursor_factory, mocked_partition_router):
    """Check that a parent state given in the initial state reaches the router and the reported stream state."""
    expected_parent_state = {"parent_cursor": "parent_state_value"}
    single_slice = StreamSlice(
        partition={"partition_field_1": "a value", "partition_field_2": "another value"},
        cursor_slice={},
    )

    # The mocked router exposes one partition and reports the parent state defined above.
    mocked_partition_router.stream_slices.return_value = [single_slice]
    mocked_partition_router.get_parent_state.return_value = expected_parent_state

    # The factory is primed with a single mocked cursor carrying CURSOR_STATE.
    underlying_cursor = (
        MockedCursorBuilder()
        .with_stream_slices([{CURSOR_SLICE_FIELD: "first slice cursor value"}])
        .with_stream_state(CURSOR_STATE)
        .build()
    )
    mocked_cursor_factory.create.side_effect = [underlying_cursor]

    per_partition_cursor = PerPartitionCursor(mocked_cursor_factory, mocked_partition_router)

    # The restored state carries both the per-partition cursor state and the parent state.
    state_to_restore = {
        "states": [{"partition": single_slice.partition, "cursor": CURSOR_STATE}],
        "parent_state": expected_parent_state,
    }
    per_partition_cursor.set_initial_state(state_to_restore)

    # The parent state must show up in the stream state reported by the cursor...
    reported_state = per_partition_cursor.get_stream_state()
    assert reported_state["parent_state"] == expected_parent_state

    # ...and the router must have been handed the full initial state exactly once.
    mocked_partition_router.set_parent_state.assert_called_once_with(state_to_restore)


def test_get_stream_state_includes_parent_state(mocked_cursor_factory, mocked_partition_router):
    """Check that the combined stream state holds every partition's cursor state plus the router's parent state."""
    router_parent_state = {"parent_cursor": "parent_state_value"}

    # (partition, cursor state) pairs, in the order the mocked router emits the partitions.
    partitions_with_states = [
        (
            {"partition_field_1": "a value", "partition_field_2": "another value"},
            {CURSOR_STATE_KEY: "first slice cursor value"},
        ),
        (
            {"partition_field_1": "another value", "partition_field_2": "yet another value"},
            {CURSOR_STATE_KEY: "second slice cursor value"},
        ),
    ]

    mocked_partition_router.stream_slices.return_value = [
        StreamSlice(partition=partition, cursor_slice={}) for partition, _ in partitions_with_states
    ]
    mocked_partition_router.get_parent_state.return_value = router_parent_state

    # One mocked cursor per partition, each reporting its own state.
    mocked_cursor_factory.create.side_effect = [
        MockedCursorBuilder().with_stream_state(state).build() for _, state in partitions_with_states
    ]

    per_partition_cursor = PerPartitionCursor(mocked_cursor_factory, mocked_partition_router)

    # Exhaust the slice generator before asking for the state.
    for _ in per_partition_cursor.stream_slices():
        pass

    assert per_partition_cursor.get_stream_state() == {
        "states": [{"partition": partition, "cursor": state} for partition, state in partitions_with_states],
        "parent_state": router_parent_state,
    }
Loading

0 comments on commit 84b40a0

Please sign in to comment.