šŸ› [airbyte-cdk] Fix bug where substreams depending on an RFR parent sā€¦
Browse files Browse the repository at this point in the history
ā€¦tream don't paginate or use existing state (#40671)
brianjlai committed Jul 11, 2024
1 parent 372aea7 commit 9e23b3f
Showing 7 changed files with 1,011 additions and 266 deletions.
airbyte-cdk/python/airbyte_cdk/sources/declarative/partition_routers/substream_partition_router.py
@@ -1,16 +1,18 @@
 #
 # Copyright (c) 2023 Airbyte, Inc., all rights reserved.
 #
 
+import logging
 from dataclasses import InitVar, dataclass
 from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Mapping, Optional, Union
 
 import dpath
-from airbyte_cdk.models import AirbyteMessage, SyncMode, Type
+from airbyte_cdk.models import AirbyteMessage
+from airbyte_cdk.models import Type as MessageType
 from airbyte_cdk.sources.declarative.interpolation.interpolated_string import InterpolatedString
 from airbyte_cdk.sources.declarative.partition_routers.partition_router import PartitionRouter
 from airbyte_cdk.sources.declarative.requesters.request_option import RequestOption, RequestOptionType
 from airbyte_cdk.sources.types import Config, Record, StreamSlice, StreamState
+from airbyte_cdk.utils import AirbyteTracedException
 
 if TYPE_CHECKING:
     from airbyte_cdk.sources.declarative.declarative_stream import DeclarativeStream
@@ -131,40 +133,70 @@ def stream_slices(self) -> Iterable[StreamSlice]:
             parent_field = parent_stream_config.parent_key.eval(self.config)  # type: ignore # parent_key is always casted to an interpolated string
             partition_field = parent_stream_config.partition_field.eval(self.config)  # type: ignore # partition_field is always casted to an interpolated string
             incremental_dependency = parent_stream_config.incremental_dependency
-            for parent_stream_slice in parent_stream.stream_slices(
-                sync_mode=SyncMode.full_refresh, cursor_field=None, stream_state=None
-            ):
-                parent_partition = parent_stream_slice.partition if parent_stream_slice else {}
-
-                # we need to read all records for the slice to update the parent stream cursor
-                stream_slices_for_parent = []
-
-                # only the stream_slice param is used in the declarative stream; stream state is set in PerPartitionCursor.set_initial_state
-                for parent_record in parent_stream.read_records(
-                    sync_mode=SyncMode.full_refresh, cursor_field=None, stream_slice=parent_stream_slice, stream_state=None
-                ):
-                    # Skip non-records (eg AirbyteLogMessage)
-                    if isinstance(parent_record, AirbyteMessage):
-                        if parent_record.type == Type.RECORD:
-                            parent_record = parent_record.record.data
-                        else:
-                            continue
-                    elif isinstance(parent_record, Record):
-                        parent_record = parent_record.data
-                    try:
-                        partition_value = dpath.get(parent_record, parent_field)
-                    except KeyError:
-                        pass
-                    else:
-                        stream_slices_for_parent.append(
-                            StreamSlice(partition={partition_field: partition_value, "parent_slice": parent_partition}, cursor_slice={})
-                        )
-
-                # update the parent state, as the parent stream has read all records for the current slice and its state is already updated
-                if incremental_dependency:
-                    self._parent_state[parent_stream.name] = parent_stream.state
-
-                yield from stream_slices_for_parent
+            stream_slices_for_parent = []
+            previous_associated_slice = None
+
+            # read_only_records() assumes the parent is not concurrent. This is currently okay since the concurrent CDK does
+            # not support either substreams or RFR, but it is something that needs to be considered once we do
+            for parent_record in parent_stream.read_only_records():
+                parent_partition = None
+                parent_associated_slice = None
+                # Skip non-records (eg AirbyteLogMessage)
+                if isinstance(parent_record, AirbyteMessage):
+                    self.logger.warning(
+                        f"Parent stream {parent_stream.name} returns records of type AirbyteMessage. This SubstreamPartitionRouter is not able to checkpoint incremental parent state."
+                    )
+                    if parent_record.type == MessageType.RECORD:
+                        parent_record = parent_record.record.data
+                    else:
+                        continue
+                elif isinstance(parent_record, Record):
+                    parent_partition = parent_record.associated_slice.partition if parent_record.associated_slice else {}
+                    parent_associated_slice = parent_record.associated_slice
+                    parent_record = parent_record.data
+                elif not isinstance(parent_record, Mapping):
+                    # The parent_record should only take the form of a Record, AirbyteMessage, or Mapping. Anything else is invalid
+                    raise AirbyteTracedException(message=f"Parent stream returned records as invalid type {type(parent_record)}")
+                try:
+                    partition_value = dpath.get(parent_record, parent_field)
+                except KeyError:
+                    pass
+                else:
+                    if incremental_dependency:
+                        if previous_associated_slice is None:
+                            previous_associated_slice = parent_associated_slice
+                        elif previous_associated_slice != parent_associated_slice:
+                            # Update the parent state, as the parent stream has read all records for the current slice
+                            # and its state is already updated.
+                            #
+                            # When the associated slice of the current record of the parent stream changes, this
+                            # indicates the parent stream has finished processing the current slice and has moved onto
+                            # the next. When this happens, we should update the partition router's current state and
+                            # flush the previous set of collected records and start a new set
+                            #
+                            # Note: One tricky aspect to take note of here is that parent_stream.state will actually
+                            # fetch the state of the previous record's slice, NOT the current record's slice.
+                            # This is because in the retriever, we only update stream state after yielding all the
+                            # records. And since we are in the middle of the current slice, parent_stream.state is
+                            # still set to the previous state.
+                            self._parent_state[parent_stream.name] = parent_stream.state
+                            yield from stream_slices_for_parent
+
+                            # Reset stream_slices_for_parent after we've flushed parent records for the previous parent slice
+                            stream_slices_for_parent = []
+                            previous_associated_slice = parent_associated_slice
+                    stream_slices_for_parent.append(
+                        StreamSlice(
+                            partition={partition_field: partition_value, "parent_slice": parent_partition or {}}, cursor_slice={}
+                        )
+                    )
+
+            # A final parent state update and yield of records is needed, so we don't skip records for the final parent slice
+            if incremental_dependency:
+                self._parent_state[parent_stream.name] = parent_stream.state
+
+            yield from stream_slices_for_parent
 
     def set_initial_state(self, stream_state: StreamState) -> None:
         """
@@ -215,3 +247,7 @@ def get_stream_state(self) -> Optional[Mapping[str, StreamState]]:
             }
         """
         return self._parent_state
+
+    @property
+    def logger(self) -> logging.Logger:
+        return logging.getLogger("airbyte.SubstreamPartitionRouter")
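
Aside on the change above: the heart of the new stream_slices() is a buffer-and-flush pattern. Slices are accumulated per parent slice and flushed, together with a parent-state checkpoint, whenever the parent's associated slice changes, with a final flush for the last slice. The stand-alone sketch below shows the same pattern in isolation; ParentRecord and route_slices are invented names for illustration, not CDK APIs.

# Stand-alone sketch of the buffer-and-flush pattern from stream_slices above.
# ParentRecord and route_slices are hypothetical names, not CDK APIs.
from dataclasses import dataclass
from typing import Iterable, List, Optional


@dataclass
class ParentRecord:
    data: dict
    associated_slice: Optional[dict]  # the parent slice this record was read from


def route_slices(records: Iterable[ParentRecord]) -> Iterable[dict]:
    buffered: List[dict] = []
    previous_slice: Optional[dict] = None
    for record in records:
        if previous_slice is None:
            previous_slice = record.associated_slice
        elif previous_slice != record.associated_slice:
            # The parent finished a slice: flush what was buffered for it.
            # (In the real router this is also where parent state is saved.)
            yield from buffered
            buffered = []
            previous_slice = record.associated_slice
        buffered.append({"partition": record.data["id"], "parent_slice": record.associated_slice or {}})
    # Final flush so the last parent slice's records are not dropped.
    yield from buffered


result = list(
    route_slices(
        [
            ParentRecord({"id": "a"}, {"page": 1}),
            ParentRecord({"id": "b"}, {"page": 1}),
            ParentRecord({"id": "c"}, {"page": 2}),
        ]
    )
)
assert [r["partition"] for r in result] == ["a", "b", "c"]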
37 changes: 32 additions & 5 deletions airbyte-cdk/python/airbyte_cdk/sources/streams/core.py
@@ -10,7 +10,7 @@
 from typing import Any, Dict, Iterable, List, Mapping, MutableMapping, Optional, Tuple, Union
 
 import airbyte_cdk.sources.utils.casing as casing
-from airbyte_cdk.models import AirbyteMessage, AirbyteStream, ConfiguredAirbyteStream, SyncMode
+from airbyte_cdk.models import AirbyteMessage, AirbyteStream, ConfiguredAirbyteStream, DestinationSyncMode, SyncMode
 from airbyte_cdk.models import Type as MessageType
 from airbyte_cdk.sources.streams.checkpoint import (
     CheckpointMode,
@@ -24,7 +24,7 @@
 
 # list of all possible HTTP methods which can be used for sending of request bodies
 from airbyte_cdk.sources.utils.schema_helpers import InternalConfig, ResourceSchemaLoader
-from airbyte_cdk.sources.utils.slice_logger import SliceLogger
+from airbyte_cdk.sources.utils.slice_logger import DebugSliceLogger, SliceLogger
 from airbyte_cdk.sources.utils.transform import TransformConfig, TypeTransformer
 from deprecated import deprecated
 
@@ -156,6 +156,7 @@ def read( # type: ignore # ignoring typing for ConnectorStateManager because o
         except AttributeError:
             pass
 
+        should_checkpoint = bool(state_manager)
         checkpoint_reader = self._get_checkpoint_reader(
             logger=logger, cursor_field=cursor_field, sync_mode=sync_mode, stream_state=stream_state
         )
@@ -193,25 +194,51 @@ def read( # type: ignore # ignoring typing for ConnectorStateManager because o
 
                     checkpoint_interval = self.state_checkpoint_interval
                     checkpoint = checkpoint_reader.get_checkpoint()
-                    if checkpoint_interval and record_counter % checkpoint_interval == 0 and checkpoint is not None:
+                    if should_checkpoint and checkpoint_interval and record_counter % checkpoint_interval == 0 and checkpoint is not None:
                         airbyte_state_message = self._checkpoint_state(checkpoint, state_manager=state_manager)
                         yield airbyte_state_message
 
                     if internal_config.is_limit_reached(record_counter):
                         break
             self._observe_state(checkpoint_reader)
             checkpoint_state = checkpoint_reader.get_checkpoint()
-            if checkpoint_state is not None:
+            if should_checkpoint and checkpoint_state is not None:
                 airbyte_state_message = self._checkpoint_state(checkpoint_state, state_manager=state_manager)
                 yield airbyte_state_message
 
             next_slice = checkpoint_reader.next()
 
         checkpoint = checkpoint_reader.get_checkpoint()
-        if checkpoint is not None:
+        if should_checkpoint and checkpoint is not None:
             airbyte_state_message = self._checkpoint_state(checkpoint, state_manager=state_manager)
             yield airbyte_state_message
 
+    def read_only_records(self, state: Optional[Mapping[str, Any]] = None) -> Iterable[StreamData]:
+        """
+        Helper method that performs a read on a stream with an optional state and emits records. If the parent stream supports
+        incremental, this operation does not update the stream's internal state (if it uses the modern state setter/getter)
+        or emit state messages.
+        """
+
+        configured_stream = ConfiguredAirbyteStream(
+            stream=AirbyteStream(
+                name=self.name,
+                json_schema={},
+                supported_sync_modes=[SyncMode.full_refresh, SyncMode.incremental],
+            ),
+            sync_mode=SyncMode.incremental if state else SyncMode.full_refresh,
+            destination_sync_mode=DestinationSyncMode.append,
+        )
+
+        yield from self.read(
+            configured_stream=configured_stream,
+            logger=self.logger,
+            slice_logger=DebugSliceLogger(),
+            stream_state=dict(state) if state else {},  # read() expects MutableMapping instead of Mapping, which is used more often
+            state_manager=None,
+            internal_config=InternalConfig(),
+        )
+
     @abstractmethod
     def read_records(
         self,
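
The read_only_records() helper added above drives Stream.read() with state_manager=None, so should_checkpoint evaluates to False and no state messages are emitted. Below is a minimal usage sketch; InMemoryStream is invented for illustration, while Stream and read_only_records() are the CDK pieces shown in this diff.

# Hypothetical stream used to demonstrate read_only_records(); InMemoryStream
# is invented for this sketch, the rest is the CDK's public Stream interface.
from typing import Any, Iterable, List, Mapping, Optional

from airbyte_cdk.models import SyncMode
from airbyte_cdk.sources.streams import Stream


class InMemoryStream(Stream):
    primary_key = None

    def get_json_schema(self) -> Mapping[str, Any]:
        # Avoid the default resource-file schema lookup for this toy stream.
        return {}

    def read_records(
        self,
        sync_mode: SyncMode,
        cursor_field: Optional[List[str]] = None,
        stream_slice: Optional[Mapping[str, Any]] = None,
        stream_state: Optional[Mapping[str, Any]] = None,
    ) -> Iterable[Mapping[str, Any]]:
        yield from [{"id": 1}, {"id": 2}]


# Yields records only; because state_manager=None inside read_only_records(),
# no state message is ever emitted and the stream's internal state is untouched.
for record in InMemoryStream().read_only_records():
    print(record)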
29 changes: 15 additions & 14 deletions airbyte-cdk/python/airbyte_cdk/sources/streams/http/http.py
@@ -9,7 +9,8 @@
 from urllib.parse import urljoin
 
 import requests
-from airbyte_cdk.models import FailureType, SyncMode
+from airbyte_cdk.models import AirbyteMessage, FailureType, SyncMode
+from airbyte_cdk.models import Type as MessageType
 from airbyte_cdk.sources.message.repository import InMemoryMessageRepository
 from airbyte_cdk.sources.streams.availability_strategy import AvailabilityStrategy
 from airbyte_cdk.sources.streams.call_rate import APIBudget
@@ -18,6 +19,7 @@
 from airbyte_cdk.sources.streams.http.error_handlers import BackoffStrategy, ErrorHandler, HttpStatusErrorHandler
 from airbyte_cdk.sources.streams.http.error_handlers.response_models import ErrorResolution, ResponseAction
 from airbyte_cdk.sources.streams.http.http_client import HttpClient
+from airbyte_cdk.sources.types import Record
 from airbyte_cdk.sources.utils.types import JsonType
 from deprecated import deprecated
 from requests.auth import AuthBase
@@ -380,19 +382,18 @@ def __init__(self, parent: HttpStream, **kwargs: Any):
     def stream_slices(
         self, sync_mode: SyncMode, cursor_field: Optional[List[str]] = None, stream_state: Optional[Mapping[str, Any]] = None
     ) -> Iterable[Optional[Mapping[str, Any]]]:
-        parent_stream_slices = self.parent.stream_slices(
-            sync_mode=SyncMode.full_refresh, cursor_field=cursor_field, stream_state=stream_state
-        )
-
-        # iterate over all parent stream_slices
-        for stream_slice in parent_stream_slices:
-            parent_records = self.parent.read_records(
-                sync_mode=SyncMode.full_refresh, cursor_field=cursor_field, stream_slice=stream_slice, stream_state=stream_state
-            )
-
-            # iterate over all parent records with current stream_slice
-            for record in parent_records:
-                yield {"parent": record}
+        # read_only_records() assumes the parent is not concurrent. This is currently okay since the concurrent CDK does
+        # not support either substreams or RFR, but it is something that needs to be considered once we do
+        for parent_record in self.parent.read_only_records(stream_state):
+            # Skip non-records (eg AirbyteLogMessage)
+            if isinstance(parent_record, AirbyteMessage):
+                if parent_record.type == MessageType.RECORD:
+                    parent_record = parent_record.record.data
+                else:
+                    continue
+            elif isinstance(parent_record, Record):
+                parent_record = parent_record.data
+            yield {"parent": parent_record}
 
 
 @deprecated(version="3.0.0", reason="You should set backoff_strategies explicitly in HttpStream.get_backoff_strategy() instead.")
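
The practical effect of the HttpSubStream change is that a child stream's slices are now derived from parent.read_only_records(stream_state), so the parent paginates and honors existing state instead of being re-read with a stateless full refresh. A hedged sketch of a parent/child pair follows; the example.com API, class names, and response shapes are all hypothetical.

# Invented parent/child pair showing the new HttpSubStream behavior; the
# example.com API, field names, and classes are hypothetical.
from typing import Any, Iterable, Mapping, Optional

import requests
from airbyte_cdk.sources.streams.http import HttpStream, HttpSubStream


class Companies(HttpStream):
    url_base = "https://api.example.com/v1/"
    primary_key = "id"

    def path(self, **kwargs: Any) -> str:
        return "companies"

    def next_page_token(self, response: requests.Response) -> Optional[Mapping[str, Any]]:
        return None  # single page, for brevity

    def parse_response(self, response: requests.Response, **kwargs: Any) -> Iterable[Mapping[str, Any]]:
        yield from response.json()["companies"]


class Employees(HttpSubStream):
    url_base = "https://api.example.com/v1/"
    primary_key = "id"

    def path(self, stream_slice: Optional[Mapping[str, Any]] = None, **kwargs: Any) -> str:
        # Each slice is {"parent": <parent record data>}, built from
        # parent.read_only_records(stream_state), so the parent paginates and
        # resumes from existing state instead of re-reading from scratch.
        return f"companies/{stream_slice['parent']['id']}/employees"

    def next_page_token(self, response: requests.Response) -> Optional[Mapping[str, Any]]:
        return None

    def parse_response(self, response: requests.Response, **kwargs: Any) -> Iterable[Mapping[str, Any]]:
        yield from response.json()["employees"]


# The child is constructed with its parent; HttpSubStream.stream_slices() then
# iterates parent records and wraps each one as {"parent": record}.
employees = Employees(parent=Companies())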
@@ -232,12 +232,12 @@ def test_substream_without_input_state():
 
     stream_instance = test_source.streams({})[1]
 
-    stream_slice = StreamSlice(partition={"parent_id": "1"},
-                               cursor_slice={"start_time": "2022-01-01", "end_time": "2022-01-31"})
+    parent_stream_slice = StreamSlice(partition={}, cursor_slice={"start_time": "2022-01-01", "end_time": "2022-01-31"})
 
+    # This mocks the resulting records of the Rates stream which acts as the parent stream of the SubstreamPartitionRouter being tested
     with patch.object(
-        SimpleRetriever, "_read_pages", side_effect=[[Record({"id": "1", CURSOR_FIELD: "2022-01-15"}, stream_slice)],
-                                                     [Record({"id": "2", CURSOR_FIELD: "2022-01-15"}, stream_slice)]]
+        SimpleRetriever, "_read_pages", side_effect=[[Record({"id": "1", CURSOR_FIELD: "2022-01-15"}, parent_stream_slice)],
+                                                     [Record({"id": "2", CURSOR_FIELD: "2022-01-15"}, parent_stream_slice)]]
     ):
         slices = list(stream_instance.stream_slices(sync_mode=SYNC_MODE))
         assert list(slices) == [
