Skip to content

Commit

Permalink
馃悰 follow up to #35471: update the cartesian stream slicer (#35865)
Browse files Browse the repository at this point in the history
  • Loading branch information
girarda committed Mar 7, 2024
1 parent 106102c commit 4a808ee
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 22 deletions.
Expand Up @@ -44,7 +44,7 @@ def get_request_params(
) -> Mapping[str, Any]:
return dict(
ChainMap(
*[
*[ # type: ignore # ChainMap expects a MutableMapping[Never, Never] for reasons
s.get_request_params(stream_state=stream_state, stream_slice=stream_slice, next_page_token=next_page_token)
for s in self.stream_slicers
]
Expand All @@ -60,7 +60,7 @@ def get_request_headers(
) -> Mapping[str, Any]:
return dict(
ChainMap(
*[
*[ # type: ignore # ChainMap expects a MutableMapping[Never, Never] for reasons
s.get_request_headers(stream_state=stream_state, stream_slice=stream_slice, next_page_token=next_page_token)
for s in self.stream_slicers
]
Expand All @@ -76,7 +76,7 @@ def get_request_body_data(
) -> Mapping[str, Any]:
return dict(
ChainMap(
*[
*[ # type: ignore # ChainMap expects a MutableMapping[Never, Never] for reasons
s.get_request_body_data(stream_state=stream_state, stream_slice=stream_slice, next_page_token=next_page_token)
for s in self.stream_slicers
]
Expand All @@ -89,10 +89,10 @@ def get_request_body_json(
stream_state: Optional[StreamState] = None,
stream_slice: Optional[StreamSlice] = None,
next_page_token: Optional[Mapping[str, Any]] = None,
) -> Optional[Mapping]:
) -> Mapping[str, Any]:
return dict(
ChainMap(
*[
*[ # type: ignore # ChainMap expects a MutableMapping[Never, Never] for reasons
s.get_request_body_json(stream_state=stream_state, stream_slice=stream_slice, next_page_token=next_page_token)
for s in self.stream_slicers
]
Expand All @@ -101,4 +101,14 @@ def get_request_body_json(

def stream_slices(self) -> Iterable[StreamSlice]:
sub_slices = (s.stream_slices() for s in self.stream_slicers)
return (dict(ChainMap(*a)) for a in itertools.product(*sub_slices))
product = itertools.product(*sub_slices)
for stream_slice_tuple in product:
partition = dict(ChainMap(*[s.partition for s in stream_slice_tuple]))
cursor_slices = [s.cursor_slice for s in stream_slice_tuple if s.cursor_slice]
if len(cursor_slices) > 1:
raise ValueError(f"There should only be a single cursor slice. Found {cursor_slices}")
if cursor_slices:
cursor_slice = cursor_slices[0]
else:
cursor_slice = {}
yield StreamSlice(partition=partition, cursor_slice=cursor_slice)
Expand Up @@ -9,6 +9,7 @@
from airbyte_cdk.sources.declarative.partition_routers.list_partition_router import ListPartitionRouter
from airbyte_cdk.sources.declarative.requesters.request_option import RequestOption, RequestOptionType
from airbyte_cdk.sources.declarative.stream_slicers.cartesian_product_stream_slicer import CartesianProductStreamSlicer
from airbyte_cdk.sources.declarative.types import StreamSlice


@pytest.mark.parametrize(
Expand All @@ -17,7 +18,9 @@
(
"test_single_stream_slicer",
[ListPartitionRouter(values=["customer", "store", "subscription"], cursor_field="owner_resource", config={}, parameters={})],
[{"owner_resource": "customer"}, {"owner_resource": "store"}, {"owner_resource": "subscription"}],
[StreamSlice(partition={"owner_resource": "customer"}, cursor_slice={}),
StreamSlice(partition={"owner_resource": "store"}, cursor_slice={}),
StreamSlice(partition={"owner_resource": "subscription"}, cursor_slice={})],
),
(
"test_two_stream_slicers",
Expand All @@ -26,14 +29,34 @@
ListPartitionRouter(values=["A", "B"], cursor_field="letter", config={}, parameters={}),
],
[
{"owner_resource": "customer", "letter": "A"},
{"owner_resource": "customer", "letter": "B"},
{"owner_resource": "store", "letter": "A"},
{"owner_resource": "store", "letter": "B"},
{"owner_resource": "subscription", "letter": "A"},
{"owner_resource": "subscription", "letter": "B"},
StreamSlice(partition={"owner_resource": "customer", "letter": "A"}, cursor_slice={}),
StreamSlice(partition={"owner_resource": "customer", "letter": "B"}, cursor_slice={}),
StreamSlice(partition={"owner_resource": "store", "letter": "A"}, cursor_slice={}),
StreamSlice(partition={"owner_resource": "store", "letter": "B"}, cursor_slice={}),
StreamSlice(partition={"owner_resource": "subscription", "letter": "A"}, cursor_slice={}),
StreamSlice(partition={"owner_resource": "subscription", "letter": "B"}, cursor_slice={}),
],
),
(
"test_singledatetime",
[
DatetimeBasedCursor(
start_datetime=MinMaxDatetime(datetime="2021-01-01", datetime_format="%Y-%m-%d", parameters={}),
end_datetime=MinMaxDatetime(datetime="2021-01-03", datetime_format="%Y-%m-%d", parameters={}),
step="P1D",
cursor_field=InterpolatedString.create("", parameters={}),
datetime_format="%Y-%m-%d",
cursor_granularity="P1D",
config={},
parameters={},
),
],
[
StreamSlice(partition={}, cursor_slice={"start_time": "2021-01-01", "end_time": "2021-01-01"}),
StreamSlice(partition={}, cursor_slice={"start_time": "2021-01-02", "end_time": "2021-01-02"}),
StreamSlice(partition={}, cursor_slice={"start_time": "2021-01-03", "end_time": "2021-01-03"}),
],
),
(
"test_list_and_datetime",
[
Expand All @@ -50,15 +73,15 @@
),
],
[
{"owner_resource": "customer", "start_time": "2021-01-01", "end_time": "2021-01-01"},
{"owner_resource": "customer", "start_time": "2021-01-02", "end_time": "2021-01-02"},
{"owner_resource": "customer", "start_time": "2021-01-03", "end_time": "2021-01-03"},
{"owner_resource": "store", "start_time": "2021-01-01", "end_time": "2021-01-01"},
{"owner_resource": "store", "start_time": "2021-01-02", "end_time": "2021-01-02"},
{"owner_resource": "store", "start_time": "2021-01-03", "end_time": "2021-01-03"},
{"owner_resource": "subscription", "start_time": "2021-01-01", "end_time": "2021-01-01"},
{"owner_resource": "subscription", "start_time": "2021-01-02", "end_time": "2021-01-02"},
{"owner_resource": "subscription", "start_time": "2021-01-03", "end_time": "2021-01-03"},
StreamSlice(partition={"owner_resource": "customer"}, cursor_slice={"start_time": "2021-01-01", "end_time": "2021-01-01"}),
StreamSlice(partition={"owner_resource": "customer"}, cursor_slice={"start_time": "2021-01-02", "end_time": "2021-01-02"}),
StreamSlice(partition={"owner_resource": "customer"}, cursor_slice={"start_time": "2021-01-03", "end_time": "2021-01-03"}),
StreamSlice(partition={"owner_resource": "store"}, cursor_slice={"start_time": "2021-01-01", "end_time": "2021-01-01"}),
StreamSlice(partition={"owner_resource": "store"}, cursor_slice={"start_time": "2021-01-02", "end_time": "2021-01-02"}),
StreamSlice(partition={"owner_resource": "store"}, cursor_slice={"start_time": "2021-01-03", "end_time": "2021-01-03"}),
StreamSlice(partition={"owner_resource": "subscription"}, cursor_slice={"start_time": "2021-01-01", "end_time": "2021-01-01"}),
StreamSlice(partition={"owner_resource": "subscription"}, cursor_slice={"start_time": "2021-01-02", "end_time": "2021-01-02"}),
StreamSlice(partition={"owner_resource": "subscription"}, cursor_slice={"start_time": "2021-01-03", "end_time": "2021-01-03"}),
],
),
],
Expand All @@ -69,6 +92,34 @@ def test_substream_slicer(test_name, stream_slicers, expected_slices):
assert slices == expected_slices


def test_stream_slices_raises_exception_if_multiple_cursor_slice_components():
stream_slicers = [
DatetimeBasedCursor(
start_datetime=MinMaxDatetime(datetime="2021-01-01", datetime_format="%Y-%m-%d", parameters={}),
end_datetime=MinMaxDatetime(datetime="2021-01-03", datetime_format="%Y-%m-%d", parameters={}),
step="P1D",
cursor_field=InterpolatedString.create("", parameters={}),
datetime_format="%Y-%m-%d",
cursor_granularity="P1D",
config={},
parameters={},
),
DatetimeBasedCursor(
start_datetime=MinMaxDatetime(datetime="2021-01-01", datetime_format="%Y-%m-%d", parameters={}),
end_datetime=MinMaxDatetime(datetime="2021-01-03", datetime_format="%Y-%m-%d", parameters={}),
step="P1D",
cursor_field=InterpolatedString.create("", parameters={}),
datetime_format="%Y-%m-%d",
cursor_granularity="P1D",
config={},
parameters={},
),
]
slicer = CartesianProductStreamSlicer(stream_slicers=stream_slicers, parameters={})
with pytest.raises(ValueError):
list(slicer.stream_slices())


@pytest.mark.parametrize(
"test_name, stream_1_request_option, stream_2_request_option, expected_req_params, expected_headers,expected_body_json, expected_body_data",
[
Expand Down

0 comments on commit 4a808ee

Please sign in to comment.