Skip to content

Commit

Permalink
Concurrent CDK: support partitioned states (#36811)
Browse files Browse the repository at this point in the history
  • Loading branch information
maxi297 committed Apr 9, 2024
1 parent f29f7bb commit bbf69ae
Show file tree
Hide file tree
Showing 21 changed files with 637 additions and 461 deletions.
Expand Up @@ -7,6 +7,7 @@
from airbyte_cdk.models import AirbyteMessage, AirbyteStreamStatus
from airbyte_cdk.models import Type as MessageType
from airbyte_cdk.sources.concurrent_source.partition_generation_completed_sentinel import PartitionGenerationCompletedSentinel
from airbyte_cdk.sources.concurrent_source.stream_thread_exception import StreamThreadException
from airbyte_cdk.sources.concurrent_source.thread_pool_manager import ThreadPoolManager
from airbyte_cdk.sources.message import MessageRepository
from airbyte_cdk.sources.streams.concurrent.abstract_stream import AbstractStream
Expand All @@ -17,7 +18,9 @@
from airbyte_cdk.sources.streams.concurrent.partitions.types import PartitionCompleteSentinel
from airbyte_cdk.sources.utils.record_helper import stream_data_to_airbyte_message
from airbyte_cdk.sources.utils.slice_logger import SliceLogger
from airbyte_cdk.utils import AirbyteTracedException
from airbyte_cdk.utils.stream_status_utils import as_airbyte_message as stream_status_as_airbyte_message
from airbyte_protocol.models import StreamDescriptor


class ConcurrentReadProcessor:
Expand Down Expand Up @@ -56,6 +59,7 @@ def __init__(
self._message_repository = message_repository
self._partition_reader = partition_reader
self._streams_done: Set[str] = set()
self._exceptions_per_stream_name: dict[str, List[Exception]] = {}

def on_partition_generation_completed(self, sentinel: PartitionGenerationCompletedSentinel) -> Iterable[AirbyteMessage]:
"""
Expand Down Expand Up @@ -126,14 +130,16 @@ def on_record(self, record: Record) -> Iterable[AirbyteMessage]:
yield message
yield from self._message_repository.consume_queue()

def on_exception(self, exception: Exception) -> Iterable[AirbyteMessage]:
def on_exception(self, exception: StreamThreadException) -> Iterable[AirbyteMessage]:
"""
This method is called when an exception is raised.
1. Stop all running streams
2. Raise the exception
"""
yield from self._stop_streams()
raise exception
self._exceptions_per_stream_name.setdefault(exception.stream_name, []).append(exception.exception)
yield AirbyteTracedException.from_exception(exception).as_airbyte_message(
stream_descriptor=StreamDescriptor(name=exception.stream_name)
)

def start_next_partition_generator(self) -> Optional[AirbyteMessage]:
"""
Expand Down Expand Up @@ -177,13 +183,7 @@ def _on_stream_is_done(self, stream_name: str) -> Iterable[AirbyteMessage]:
yield from self._message_repository.consume_queue()
self._logger.info(f"Finished syncing {stream.name}")
self._streams_done.add(stream_name)
yield stream_status_as_airbyte_message(stream.as_airbyte_stream(), AirbyteStreamStatus.COMPLETE)

def _stop_streams(self) -> Iterable[AirbyteMessage]:
self._thread_pool_manager.shutdown()
for stream_name in self._streams_to_running_partitions.keys():
stream = self._stream_name_to_instance[stream_name]
if not self._is_stream_done(stream_name):
self._logger.info(f"Marking stream {stream.name} as STOPPED")
self._logger.info(f"Finished syncing {stream.name}")
yield stream_status_as_airbyte_message(stream.as_airbyte_stream(), AirbyteStreamStatus.INCOMPLETE)
stream_status = (
AirbyteStreamStatus.INCOMPLETE if self._exceptions_per_stream_name.get(stream_name, []) else AirbyteStreamStatus.COMPLETE
)
yield stream_status_as_airbyte_message(stream.as_airbyte_stream(), stream_status)
Expand Up @@ -9,6 +9,7 @@
from airbyte_cdk.models import AirbyteMessage
from airbyte_cdk.sources.concurrent_source.concurrent_read_processor import ConcurrentReadProcessor
from airbyte_cdk.sources.concurrent_source.partition_generation_completed_sentinel import PartitionGenerationCompletedSentinel
from airbyte_cdk.sources.concurrent_source.stream_thread_exception import StreamThreadException
from airbyte_cdk.sources.concurrent_source.thread_pool_manager import ThreadPoolManager
from airbyte_cdk.sources.message import InMemoryMessageRepository, MessageRepository
from airbyte_cdk.sources.streams.concurrent.abstract_stream import AbstractStream
Expand Down Expand Up @@ -123,11 +124,6 @@ def _consume_from_queue(
concurrent_stream_processor: ConcurrentReadProcessor,
) -> Iterable[AirbyteMessage]:
while airbyte_message_or_record_or_exception := queue.get():
try:
self._threadpool.shutdown_if_exception()
except Exception as exception:
concurrent_stream_processor.on_exception(exception)

yield from self._handle_item(
airbyte_message_or_record_or_exception,
concurrent_stream_processor,
Expand All @@ -142,7 +138,7 @@ def _handle_item(
concurrent_stream_processor: ConcurrentReadProcessor,
) -> Iterable[AirbyteMessage]:
# handle queue item and call the appropriate handler depending on the type of the queue item
if isinstance(queue_item, Exception):
if isinstance(queue_item, StreamThreadException):
yield from concurrent_stream_processor.on_exception(queue_item)
elif isinstance(queue_item, PartitionGenerationCompletedSentinel):
yield from concurrent_stream_processor.on_partition_generation_completed(queue_item)
Expand Down
@@ -0,0 +1,25 @@
# Copyright (c) 2024 Airbyte, Inc., all rights reserved.

from typing import Any


class StreamThreadException(Exception):
def __init__(self, exception: Exception, stream_name: str):
self._exception = exception
self._stream_name = stream_name

@property
def stream_name(self) -> str:
return self._stream_name

@property
def exception(self) -> Exception:
return self._exception

def __str__(self) -> str:
return f"Exception while syncing stream {self._stream_name}: {self._exception}"

def __eq__(self, other: Any) -> bool:
if isinstance(other, StreamThreadException):
return self._exception == other._exception and self._stream_name == other._stream_name
return False
Expand Up @@ -71,26 +71,26 @@ def _prune_futures(self, futures: List[Future[Any]]) -> None:
)
futures.pop(index)

def shutdown(self) -> None:
def _shutdown(self) -> None:
# Without a way to stop the threads that have already started, this will not stop the Python application. We are fine today with
# this imperfect approach because we only do this in case of `self._most_recently_seen_exception` which we don't expect to happen
self._threadpool.shutdown(wait=False, cancel_futures=True)

def is_done(self) -> bool:
return all([f.done() for f in self._futures])

def shutdown_if_exception(self) -> None:
"""
This method will raise if there is an exception so that the caller can use it.
"""
if self._most_recently_seen_exception:
self._stop_and_raise_exception(self._most_recently_seen_exception)

def check_for_errors_and_shutdown(self) -> None:
"""
Check if any of the futures have an exception, and raise it if so. If all futures are done, shutdown the threadpool.
If the futures are not done, raise an exception.
:return:
"""
self.shutdown_if_exception()
if self._most_recently_seen_exception:
self._logger.exception(
"An unknown exception has occurred while reading concurrently",
exc_info=self._most_recently_seen_exception,
)
self._stop_and_raise_exception(self._most_recently_seen_exception)

exceptions_from_futures = [f for f in [future.exception() for future in self._futures] if f is not None]
if exceptions_from_futures:
Expand All @@ -102,8 +102,8 @@ def check_for_errors_and_shutdown(self) -> None:
exception = RuntimeError(f"Failed reading with futures not done: {futures_not_done}")
self._stop_and_raise_exception(exception)
else:
self.shutdown()
self._shutdown()

def _stop_and_raise_exception(self, exception: BaseException) -> None:
self.shutdown()
self._shutdown()
raise exception
113 changes: 100 additions & 13 deletions airbyte-cdk/python/airbyte_cdk/sources/streams/concurrent/cursor.py
Expand Up @@ -3,8 +3,7 @@
#
import functools
from abc import ABC, abstractmethod
from datetime import datetime
from typing import Any, List, Mapping, MutableMapping, Optional, Protocol, Tuple
from typing import Any, Callable, Iterable, List, Mapping, MutableMapping, Optional, Protocol, Tuple

from airbyte_cdk.sources.connector_state_manager import ConnectorStateManager
from airbyte_cdk.sources.message import MessageRepository
Expand All @@ -18,19 +17,41 @@ def _extract_value(mapping: Mapping[str, Any], path: List[str]) -> Any:
return functools.reduce(lambda a, b: a[b], path, mapping)


class Comparable(Protocol):
class GapType(Protocol):
"""
This is the representation of gaps between two cursor values. Examples:
* if cursor values are datetimes, GapType is timedelta
* if cursor values are integer, GapType will also be integer
"""

pass


class CursorValueType(Protocol):
"""Protocol for annotating comparable types."""

@abstractmethod
def __lt__(self: "Comparable", other: "Comparable") -> bool:
def __lt__(self: "CursorValueType", other: "CursorValueType") -> bool:
pass

@abstractmethod
def __ge__(self: "CursorValueType", other: "CursorValueType") -> bool:
pass

@abstractmethod
def __add__(self: "CursorValueType", other: GapType) -> "CursorValueType":
pass

@abstractmethod
def __sub__(self: "CursorValueType", other: GapType) -> "CursorValueType":
pass


class CursorField:
def __init__(self, cursor_field_key: str) -> None:
self.cursor_field_key = cursor_field_key

def extract_value(self, record: Record) -> Comparable:
def extract_value(self, record: Record) -> CursorValueType:
cursor_value = record.data.get(self.cursor_field_key)
if cursor_value is None:
raise ValueError(f"Could not find cursor field {self.cursor_field_key} in record")
Expand Down Expand Up @@ -118,7 +139,10 @@ def __init__(
connector_state_converter: AbstractStreamStateConverter,
cursor_field: CursorField,
slice_boundary_fields: Optional[Tuple[str, str]],
start: Optional[Any],
start: Optional[CursorValueType],
end_provider: Callable[[], CursorValueType],
lookback_window: Optional[GapType] = None,
slice_range: Optional[GapType] = None,
) -> None:
self._stream_name = stream_name
self._stream_namespace = stream_namespace
Expand All @@ -129,15 +153,18 @@ def __init__(
# To see some example where the slice boundaries might not be defined, check https://github.com/airbytehq/airbyte/blob/1ce84d6396e446e1ac2377362446e3fb94509461/airbyte-integrations/connectors/source-stripe/source_stripe/streams.py#L363-L379
self._slice_boundary_fields = slice_boundary_fields if slice_boundary_fields else tuple()
self._start = start
self._end_provider = end_provider
self._most_recent_record: Optional[Record] = None
self._has_closed_at_least_one_slice = False
self.start, self._concurrent_state = self._get_concurrent_state(stream_state)
self._lookback_window = lookback_window
self._slice_range = slice_range

@property
def state(self) -> MutableMapping[str, Any]:
return self._concurrent_state

def _get_concurrent_state(self, state: MutableMapping[str, Any]) -> Tuple[datetime, MutableMapping[str, Any]]:
def _get_concurrent_state(self, state: MutableMapping[str, Any]) -> Tuple[CursorValueType, MutableMapping[str, Any]]:
if self._connector_state_converter.is_state_message_compatible(state):
return self._start or self._connector_state_converter.zero_value, self._connector_state_converter.deserialize(state)
return self._connector_state_converter.convert_from_sequential_state(self._cursor_field, state, self._start)
Expand Down Expand Up @@ -203,23 +230,20 @@ def _emit_state_message(self) -> None:
self._connector_state_manager.update_state_for_stream(
self._stream_name,
self._stream_namespace,
self._connector_state_converter.convert_to_sequential_state(self._cursor_field, self.state),
self._connector_state_converter.convert_to_state_message(self._cursor_field, self.state),
)
# TODO: if we migrate stored state to the concurrent state format
# (aka stop calling self._connector_state_converter.convert_to_sequential_state`), we'll need to cast datetimes to string or
# int before emitting state
state_message = self._connector_state_manager.create_state_message(self._stream_name, self._stream_namespace)
self._message_repository.emit_message(state_message)

def _merge_partitions(self) -> None:
self.state["slices"] = self._connector_state_converter.merge_intervals(self.state["slices"])

def _extract_from_slice(self, partition: Partition, key: str) -> Comparable:
def _extract_from_slice(self, partition: Partition, key: str) -> CursorValueType:
try:
_slice = partition.to_slice()
if not _slice:
raise KeyError(f"Could not find key `{key}` in empty slice")
return self._connector_state_converter.parse_value(_slice[key]) # type: ignore # we expect the devs to specify a key that would return a Comparable
return self._connector_state_converter.parse_value(_slice[key]) # type: ignore # we expect the devs to specify a key that would return a CursorValueType
except KeyError as exception:
raise KeyError(f"Partition is expected to have key `{key}` but could not be found") from exception

Expand All @@ -229,3 +253,66 @@ def ensure_at_least_one_state_emitted(self) -> None:
called.
"""
self._emit_state_message()

def generate_slices(self) -> Iterable[Tuple[CursorValueType, CursorValueType]]:
"""
Generating slices based on a few parameters:
* lookback_window: Buffer to remove from END_KEY of the highest slice
* slice_range: Max difference between two slices. If the difference between two slices is greater, multiple slices will be created
* start: `_split_per_slice_range` will clip any value to `self._start which means that:
* if upper is less than self._start, no slices will be generated
* if lower is less than self._start, self._start will be used as the lower boundary (lookback_window will not be considered in that case)
Note that the slices will overlap at their boundaries. We therefore expect to have at least the lower or the upper boundary to be
inclusive in the API that is queried.
"""
self._merge_partitions()

if self._start is not None and self._is_start_before_first_slice():
yield from self._split_per_slice_range(self._start, self.state["slices"][0][self._connector_state_converter.START_KEY])

if len(self.state["slices"]) == 1:
yield from self._split_per_slice_range(
self._calculate_lower_boundary_of_last_slice(self.state["slices"][0][self._connector_state_converter.END_KEY]),
self._end_provider(),
)
elif len(self.state["slices"]) > 1:
for i in range(len(self.state["slices"]) - 1):
yield from self._split_per_slice_range(
self.state["slices"][i][self._connector_state_converter.END_KEY],
self.state["slices"][i + 1][self._connector_state_converter.START_KEY],
)
yield from self._split_per_slice_range(
self._calculate_lower_boundary_of_last_slice(self.state["slices"][-1][self._connector_state_converter.END_KEY]),
self._end_provider(),
)
else:
raise ValueError("Expected at least one slice")

def _is_start_before_first_slice(self) -> bool:
return self._start is not None and self._start < self.state["slices"][0][self._connector_state_converter.START_KEY]

def _calculate_lower_boundary_of_last_slice(self, lower_boundary: CursorValueType) -> CursorValueType:
if self._lookback_window:
return lower_boundary - self._lookback_window
return lower_boundary

def _split_per_slice_range(self, lower: CursorValueType, upper: CursorValueType) -> Iterable[Tuple[CursorValueType, CursorValueType]]:
if lower >= upper:
return

if self._start and upper < self._start:
return

lower = max(lower, self._start) if self._start else lower
if not self._slice_range or lower + self._slice_range >= upper:
yield lower, upper
else:
stop_processing = False
current_lower_boundary = lower
while not stop_processing:
current_upper_boundary = min(current_lower_boundary + self._slice_range, upper)
yield current_lower_boundary, current_upper_boundary
current_lower_boundary = current_upper_boundary
if current_upper_boundary >= upper:
stop_processing = True
Expand Up @@ -5,6 +5,7 @@
from queue import Queue

from airbyte_cdk.sources.concurrent_source.partition_generation_completed_sentinel import PartitionGenerationCompletedSentinel
from airbyte_cdk.sources.concurrent_source.stream_thread_exception import StreamThreadException
from airbyte_cdk.sources.concurrent_source.thread_pool_manager import ThreadPoolManager
from airbyte_cdk.sources.streams.concurrent.abstract_stream import AbstractStream
from airbyte_cdk.sources.streams.concurrent.partitions.types import QueueItem
Expand Down Expand Up @@ -52,4 +53,5 @@ def generate_partitions(self, stream: AbstractStream) -> None:
self._queue.put(partition)
self._queue.put(PartitionGenerationCompletedSentinel(stream))
except Exception as e:
self._queue.put(e)
self._queue.put(StreamThreadException(e, stream.name))
self._queue.put(PartitionGenerationCompletedSentinel(stream))
Expand Up @@ -3,6 +3,7 @@
#
from queue import Queue

from airbyte_cdk.sources.concurrent_source.stream_thread_exception import StreamThreadException
from airbyte_cdk.sources.streams.concurrent.partitions.partition import Partition
from airbyte_cdk.sources.streams.concurrent.partitions.types import PartitionCompleteSentinel, QueueItem

Expand Down Expand Up @@ -35,4 +36,5 @@ def process_partition(self, partition: Partition) -> None:
self._queue.put(record)
self._queue.put(PartitionCompleteSentinel(partition))
except Exception as e:
self._queue.put(e)
self._queue.put(StreamThreadException(e, partition.stream_name()))
self._queue.put(PartitionCompleteSentinel(partition))
Expand Up @@ -2,7 +2,7 @@
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#

from typing import Union
from typing import Any, Union

from airbyte_cdk.sources.concurrent_source.partition_generation_completed_sentinel import PartitionGenerationCompletedSentinel
from airbyte_cdk.sources.streams.concurrent.partitions.partition import Partition
Expand All @@ -21,6 +21,11 @@ def __init__(self, partition: Partition):
"""
self.partition = partition

def __eq__(self, other: Any) -> bool:
if isinstance(other, PartitionCompleteSentinel):
return self.partition == other.partition
return False


"""
Typedef representing the items that can be added to the ThreadBasedConcurrentStream
Expand Down

0 comments on commit bbf69ae

Please sign in to comment.