Skip to content

Commit

Permalink
Concurrent CDK: fix state message ordering (#34131)
Browse files Browse the repository at this point in the history
  • Loading branch information
clnoll committed Jan 18, 2024
1 parent 5f35187 commit e3e58cc
Show file tree
Hide file tree
Showing 6 changed files with 259 additions and 323 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
#
import functools
from abc import ABC, abstractmethod
from typing import Any, List, Mapping, Optional, Protocol, Tuple
from datetime import datetime
from typing import Any, List, Mapping, MutableMapping, Optional, Protocol, Tuple

from airbyte_cdk.sources.connector_state_manager import ConnectorStateManager
from airbyte_cdk.sources.message import MessageRepository
Expand Down Expand Up @@ -36,6 +37,11 @@ def extract_value(self, record: Record) -> Comparable:


class Cursor(ABC):
@property
@abstractmethod
def state(self) -> MutableMapping[str, Any]:
...

@abstractmethod
def observe(self, record: Record) -> None:
"""
Expand All @@ -52,6 +58,10 @@ def close_partition(self, partition: Partition) -> None:


class NoopCursor(Cursor):
@property
def state(self) -> MutableMapping[str, Any]:
return {}

def observe(self, record: Record) -> None:
pass

Expand All @@ -73,6 +83,7 @@ def __init__(
connector_state_converter: AbstractStreamStateConverter,
cursor_field: CursorField,
slice_boundary_fields: Optional[Tuple[str, str]],
start: Optional[Any],
) -> None:
self._stream_name = stream_name
self._stream_namespace = stream_namespace
Expand All @@ -82,9 +93,19 @@ def __init__(
self._cursor_field = cursor_field
# To see some example where the slice boundaries might not be defined, check https://github.com/airbytehq/airbyte/blob/1ce84d6396e446e1ac2377362446e3fb94509461/airbyte-integrations/connectors/source-stripe/source_stripe/streams.py#L363-L379
self._slice_boundary_fields = slice_boundary_fields if slice_boundary_fields else tuple()
self._start = start
self._most_recent_record: Optional[Record] = None
self._has_closed_at_least_one_slice = False
self.state = stream_state
self.start, self._concurrent_state = self._get_concurrent_state(stream_state)

@property
def state(self) -> MutableMapping[str, Any]:
return self._concurrent_state

def _get_concurrent_state(self, state: MutableMapping[str, Any]) -> Tuple[datetime, MutableMapping[str, Any]]:
if self._connector_state_converter.is_state_message_compatible(state):
return self._start or self._connector_state_converter.zero_value, self._connector_state_converter.deserialize(state)
return self._connector_state_converter.convert_from_sequential_state(self._cursor_field, state, self._start)

def observe(self, record: Record) -> None:
if self._slice_boundary_fields:
Expand All @@ -102,15 +123,17 @@ def _extract_cursor_value(self, record: Record) -> Any:
def close_partition(self, partition: Partition) -> None:
slice_count_before = len(self.state.get("slices", []))
self._add_slice_to_state(partition)
if slice_count_before < len(self.state["slices"]):
if slice_count_before < len(self.state["slices"]): # only emit if at least one slice has been processed
self._merge_partitions()
self._emit_state_message()
self._has_closed_at_least_one_slice = True

def _add_slice_to_state(self, partition: Partition) -> None:
if self._slice_boundary_fields:
if "slices" not in self.state:
self.state["slices"] = []
raise RuntimeError(
f"The state for stream {self._stream_name} should have at least one slice to delineate the sync start time, but no slices are present. This is unexpected. Please contact Support."
)
self.state["slices"].append(
{
"start": self._extract_from_slice(partition, self._slice_boundary_fields[self._START_BOUNDARY]),
Expand All @@ -126,10 +149,8 @@ def _add_slice_to_state(self, partition: Partition) -> None:

self.state["slices"].append(
{
# TODO: if we migrate stored state to the concurrent state format, we may want this to be the config start date
# instead of zero_value.
"start": self._connector_state_converter.zero_value,
"end": self._extract_cursor_value(self._most_recent_record),
self._connector_state_converter.START_KEY: self.start,
self._connector_state_converter.END_KEY: self._extract_cursor_value(self._most_recent_record),
}
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from abc import ABC, abstractmethod
from enum import Enum
from typing import TYPE_CHECKING, Any, List, MutableMapping, Optional
from typing import TYPE_CHECKING, Any, List, MutableMapping, Tuple

if TYPE_CHECKING:
from airbyte_cdk.sources.streams.concurrent.cursor import CursorField
Expand All @@ -18,15 +18,6 @@ class AbstractStreamStateConverter(ABC):
START_KEY = "start"
END_KEY = "end"

def get_concurrent_stream_state(
self, cursor_field: Optional["CursorField"], state: MutableMapping[str, Any]
) -> Optional[MutableMapping[str, Any]]:
if not cursor_field:
return None
if self.is_state_message_compatible(state):
return self.deserialize(state)
return self.convert_from_sequential_state(cursor_field, state)

@abstractmethod
def deserialize(self, state: MutableMapping[str, Any]) -> MutableMapping[str, Any]:
"""
Expand All @@ -40,8 +31,11 @@ def is_state_message_compatible(state: MutableMapping[str, Any]) -> bool:

@abstractmethod
def convert_from_sequential_state(
self, cursor_field: "CursorField", stream_state: MutableMapping[str, Any]
) -> MutableMapping[str, Any]:
self,
cursor_field: "CursorField",
stream_state: MutableMapping[str, Any],
start: Any,
) -> Tuple[Any, MutableMapping[str, Any]]:
"""
Convert the state message to the format required by the ConcurrentCursor.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from abc import abstractmethod
from datetime import datetime, timedelta
from typing import Any, List, MutableMapping, Optional
from typing import Any, List, MutableMapping, Optional, Tuple

import pendulum
from airbyte_cdk.sources.streams.concurrent.cursor import CursorField
Expand All @@ -16,9 +16,6 @@


class DateTimeStreamStateConverter(AbstractStreamStateConverter):
START_KEY = "start"
END_KEY = "end"

@property
@abstractmethod
def _zero_value(self) -> Any:
Expand Down Expand Up @@ -62,18 +59,20 @@ def merge_intervals(self, intervals: List[MutableMapping[str, datetime]]) -> Lis
for interval in sorted_intervals[1:]:
last_end_time = merged_intervals[-1][self.END_KEY]
current_start_time = interval[self.START_KEY]
if self.compare_intervals(last_end_time, current_start_time):
if self._compare_intervals(last_end_time, current_start_time):
merged_end_time = max(last_end_time, interval[self.END_KEY])
merged_intervals[-1][self.END_KEY] = merged_end_time
else:
merged_intervals.append(interval)

return merged_intervals

def compare_intervals(self, end_time: Any, start_time: Any) -> bool:
def _compare_intervals(self, end_time: Any, start_time: Any) -> bool:
return bool(self.increment(end_time) >= start_time)

def convert_from_sequential_state(self, cursor_field: CursorField, stream_state: MutableMapping[str, Any]) -> MutableMapping[str, Any]:
def convert_from_sequential_state(
self, cursor_field: CursorField, stream_state: MutableMapping[str, Any], start: datetime
) -> Tuple[datetime, MutableMapping[str, Any]]:
"""
Convert the state message to the format required by the ConcurrentCursor.
Expand All @@ -82,28 +81,35 @@ def convert_from_sequential_state(self, cursor_field: CursorField, stream_state:
"state_type": ConcurrencyCompatibleStateType.date_range.value,
"metadata": { … },
"slices": [
{starts: 0, end: "2021-01-18T21:18:20.000+00:00", finished_processing: true}]
{"start": "2021-01-18T21:18:20.000+00:00", "end": "2021-01-18T21:18:20.000+00:00"},
]
}
"""
sync_start = self._get_sync_start(cursor_field, stream_state, start)
if self.is_state_message_compatible(stream_state):
return stream_state
if cursor_field.cursor_field_key in stream_state:
slices = [
{
# TODO: if we migrate stored state to the concurrent state format, we may want this to be the config start date
# instead of `zero_value`
self.START_KEY: self.zero_value,
self.END_KEY: self.parse_timestamp(stream_state[cursor_field.cursor_field_key]),
},
]
else:
slices = []
return {
return sync_start, stream_state

# Create a slice to represent the records synced during prior syncs.
# The start and end are the same to avoid confusion as to whether the records for this slice
# were actually synced
slices = [{self.START_KEY: sync_start, self.END_KEY: sync_start}]

return sync_start, {
"state_type": ConcurrencyCompatibleStateType.date_range.value,
"slices": slices,
"legacy": stream_state,
}

def _get_sync_start(self, cursor_field: CursorField, stream_state: MutableMapping[str, Any], start: Optional[Any]) -> datetime:
sync_start = self.parse_timestamp(start) if start is not None else self.zero_value
prev_sync_low_water_mark = (
self.parse_timestamp(stream_state[cursor_field.cursor_field_key]) if cursor_field.cursor_field_key in stream_state else None
)
if prev_sync_low_water_mark and prev_sync_low_water_mark >= sync_start:
return prev_sync_low_water_mark
else:
return sync_start

def convert_to_sequential_state(self, cursor_field: CursorField, stream_state: MutableMapping[str, Any]) -> MutableMapping[str, Any]:
"""
Convert the state message from the concurrency-compatible format to the stream's original format.
Expand All @@ -113,10 +119,9 @@ def convert_to_sequential_state(self, cursor_field: CursorField, stream_state: M
"""
if self.is_state_message_compatible(stream_state):
legacy_state = stream_state.get("legacy", {})
if slices := stream_state.pop("slices", None):
latest_complete_time = self._get_latest_complete_time(slices)
if latest_complete_time:
legacy_state.update({cursor_field.cursor_field_key: self.output_format(latest_complete_time)})
latest_complete_time = self._get_latest_complete_time(stream_state.get("slices", []))
if latest_complete_time is not None:
legacy_state.update({cursor_field.cursor_field_key: self.output_format(latest_complete_time)})
return legacy_state or {}
else:
return stream_state
Expand All @@ -125,11 +130,12 @@ def _get_latest_complete_time(self, slices: List[MutableMapping[str, Any]]) -> O
"""
Get the latest time before which all records have been processed.
"""
if slices:
first_interval = self.merge_intervals(slices)[0][self.END_KEY]
return first_interval
else:
return None
if not slices:
raise RuntimeError("Expected at least one slice but there were none. This is unexpected; please contact Support.")

merged_intervals = self.merge_intervals(slices)
first_interval = merged_intervals[0]
return first_interval[self.END_KEY]


class EpochValueConcurrentStreamStateConverter(DateTimeStreamStateConverter):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,7 @@ def check_connection(self, logger: logging.Logger, config: Mapping[str, Any]) ->
def streams(self, config: Mapping[str, Any]) -> List[Stream]:
state_manager = ConnectorStateManager(stream_instance_map={s.name: s for s in self._streams}, state=self._state)
state_converter = StreamFacadeConcurrentConnectorStateConverter()
stream_states = [state_converter.get_concurrent_stream_state(self._cursor_field, state_manager.get_stream_state(stream.name, stream.namespace))
for stream in self._streams]
stream_states = [state_manager.get_stream_state(stream.name, stream.namespace) for stream in self._streams]
return [
StreamFacade.create_from_stream(
stream,
Expand All @@ -69,6 +68,7 @@ def streams(self, config: Mapping[str, Any]) -> List[Stream]:
state_converter,
self._cursor_field,
self._cursor_boundaries,
None,
)
if self._cursor_field
else NoopCursor(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,33 +45,50 @@ def _cursor_with_slice_boundary_fields(self) -> ConcurrentCursor:
return ConcurrentCursor(
_A_STREAM_NAME,
_A_STREAM_NAMESPACE,
self._state_converter.get_concurrent_stream_state(CursorField(_A_CURSOR_FIELD_KEY), {}),
{},
self._message_repository,
self._state_manager,
self._state_converter,
CursorField(_A_CURSOR_FIELD_KEY),
_SLICE_BOUNDARY_FIELDS,
None,
)

def _cursor_without_slice_boundary_fields(self) -> ConcurrentCursor:
return ConcurrentCursor(
_A_STREAM_NAME,
_A_STREAM_NAMESPACE,
self._state_converter.get_concurrent_stream_state(CursorField(_A_CURSOR_FIELD_KEY), {}),
{},
self._message_repository,
self._state_manager,
self._state_converter,
CursorField(_A_CURSOR_FIELD_KEY),
None,
None,
)

def test_given_boundary_fields_when_close_partition_then_emit_state(self) -> None:
self._cursor_with_slice_boundary_fields().close_partition(
cursor = self._cursor_with_slice_boundary_fields()
cursor.close_partition(
_partition(
{_LOWER_SLICE_BOUNDARY_FIELD: 12, _UPPER_SLICE_BOUNDARY_FIELD: 30},
)
)

self._message_repository.emit_message.assert_called_once_with(self._state_manager.create_state_message.return_value)
self._state_manager.update_state_for_stream.assert_called_once_with(
_A_STREAM_NAME,
_A_STREAM_NAMESPACE,
{_A_CURSOR_FIELD_KEY: 0}, # State message is updated to the legacy format before being emitted
)

def test_given_boundary_fields_when_close_partition_then_emit_updated_state(self) -> None:
self._cursor_with_slice_boundary_fields().close_partition(
_partition(
{_LOWER_SLICE_BOUNDARY_FIELD: 0, _UPPER_SLICE_BOUNDARY_FIELD: 30},
)
)

self._message_repository.emit_message.assert_called_once_with(self._state_manager.create_state_message.return_value)
self._state_manager.update_state_for_stream.assert_called_once_with(
_A_STREAM_NAME,
Expand Down
Loading

0 comments on commit e3e58cc

Please sign in to comment.