Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -429,17 +429,6 @@ def _send_log(self, level: Level, message: str) -> None:
)
)

def is_greater_than_or_equal(self, first: Record, second: Record) -> bool:
    """
    Tell whether `first` ranks at or above `second` according to the cursor field.

    The cursor field name is evaluated against the config and looked up on both records.
    When both records carry a truthy cursor value, the parsed dates are compared. Otherwise
    the result is True only if `first` is the record that carries a truthy cursor value.
    """
    field_name = self.cursor_field.eval(self.config)  # type: ignore # cursor_field is converted to an InterpolatedString in __post_init__
    left_value = first.get(field_name)
    right_value = second.get(field_name)

    # Without a cursor value on the first record, it can never be considered greater.
    if not left_value:
        return False
    # Only the first record has a cursor value, so it wins by default.
    if not right_value:
        return True
    return self.parse_date(left_value) >= self.parse_date(right_value)

def set_runtime_lookback_window(self, lookback_window_in_seconds: int) -> None:
"""
Updates the lookback window based on a given number of seconds if the new duration
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -338,12 +338,6 @@ def get_request_body_json(
def should_be_synced(self, record: Record) -> bool:
    """Delegate the sync decision to the stream cursor, passing it the converted cursor record."""
    cursor_record = self._convert_record_to_cursor_record(record)
    return self._stream_cursor.should_be_synced(cursor_record)

def is_greater_than_or_equal(self, first: Record, second: Record) -> bool:
    """Compare two records through the stream cursor after converting each to a cursor record."""
    converted_first = self._convert_record_to_cursor_record(first)
    converted_second = self._convert_record_to_cursor_record(second)
    return self._stream_cursor.is_greater_than_or_equal(converted_first, converted_second)

@staticmethod
def _convert_record_to_cursor_record(record: Record) -> Record:
return Record(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -315,21 +315,6 @@ def should_be_synced(self, record: Record) -> bool:
self._convert_record_to_cursor_record(record)
)

def is_greater_than_or_equal(self, first: Record, second: Record) -> bool:
    """
    Compare two records that belong to the same partition.

    The comparison itself is delegated to the cursor returned by `_get_cursor(first)`,
    which receives both records converted to cursor records.

    Raises:
        ValueError: if either record has no associated slice, or if the two records'
            associated slices have different partitions.
    """
    first_slice = first.associated_slice
    second_slice = second.associated_slice

    if not (first_slice and second_slice):
        raise ValueError(
            f"Both records should have an associated slice but got {first.associated_slice} and {second.associated_slice}"
        )

    # Records from different partitions are tracked by different cursors and cannot be ordered.
    if first_slice.partition != second_slice.partition:
        raise ValueError(
            f"To compare records, partition should be the same but got {first.associated_slice.partition} and {second.associated_slice.partition}"
        )

    partition_cursor = self._get_cursor(first)
    return partition_cursor.is_greater_than_or_equal(
        self._convert_record_to_cursor_record(first),
        self._convert_record_to_cursor_record(second),
    )

@staticmethod
def _convert_record_to_cursor_record(record: Record) -> Record:
return Record(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,3 @@ def get_request_body_json(

def should_be_synced(self, record: Record) -> bool:
    """Ask the currently active cursor whether the record should be synced."""
    active_cursor = self._get_active_cursor()
    return active_cursor.should_be_synced(record)

def is_greater_than_or_equal(self, first: Record, second: Record) -> bool:
    """Forward the record comparison to the global cursor and return its answer."""
    global_cursor = self._global_cursor
    return global_cursor.is_greater_than_or_equal(first, second)
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,6 @@ def should_be_synced(self, record: Record) -> bool:
"""
return True

def is_greater_than_or_equal(self, first: Record, second: Record) -> bool:
    """
    Always answer False: RFR records have no ordering by which they could be compared with one another.
    """
    return False

def select_state(self, stream_slice: Optional[StreamSlice] = None) -> Optional[StreamState]:
    """Return the cursor held by this instance; `stream_slice` is not consulted."""
    # A top-level RFR cursor only manages the state of a single partition
    return self._cursor
Expand Down
27 changes: 2 additions & 25 deletions airbyte_cdk/sources/declarative/retrievers/simple_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@
from airbyte_cdk.sources.declarative.requesters.requester import Requester
from airbyte_cdk.sources.declarative.retrievers.retriever import Retriever
from airbyte_cdk.sources.declarative.stream_slicers.stream_slicer import StreamSlicer
from airbyte_cdk.sources.http_logger import format_http_message
from airbyte_cdk.sources.source import ExperimentalClassWarning
from airbyte_cdk.sources.streams.core import StreamData
from airbyte_cdk.sources.types import Config, Record, StreamSlice, StreamState
Expand Down Expand Up @@ -528,35 +527,13 @@ def read_records(
if self.cursor and current_record:
self.cursor.observe(_slice, current_record)

# Latest record read, not necessarily within slice boundaries.
# TODO Remove once all custom components implement `observe` method.
# https://github.com/airbytehq/airbyte-internal-issues/issues/6955
most_recent_record_from_slice = self._get_most_recent_record(
most_recent_record_from_slice, current_record, _slice
)
yield stream_data

if self.cursor:
self.cursor.close_slice(_slice, most_recent_record_from_slice)
self.cursor.close_slice(_slice)
return

def _get_most_recent_record(
    self,
    current_most_recent: Optional[Record],
    current_record: Optional[Record],
    stream_slice: StreamSlice,
) -> Optional[Record]:
    """
    Pick the record to keep as "most recent" after reading `current_record`.

    Returns None when there is no cursor or no current record. Otherwise returns
    `current_record` if nothing was tracked yet, or whichever of the two records the
    cursor reports as greater (the tracked one wins ties). `stream_slice` is accepted
    but not used by this method.
    """
    if not (self.cursor and current_record):
        return None
    if not current_most_recent:
        return current_record
    if self.cursor.is_greater_than_or_equal(current_most_recent, current_record):
        return current_most_recent
    return current_record
# FIXME based on the comment above in SimpleRetriever.read_records, it seems like we can tackle https://github.com/airbytehq/airbyte-internal-issues/issues/6955 and remove this

def _extract_record(
self, stream_data: StreamData, stream_slice: StreamSlice
Expand Down
6 changes: 0 additions & 6 deletions airbyte_cdk/sources/streams/checkpoint/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,6 @@ def should_be_synced(self, record: Record) -> bool:
Evaluating if a record should be synced allows for filtering and stop condition on pagination
"""

@abstractmethod
def is_greater_than_or_equal(self, first: Record, second: Record) -> bool:
    """
    Decide which of two records is greater with respect to the cursor. Having this comparison avoids the need to keep every record around in order to close a slice
    """

@abstractmethod
def select_state(self, stream_slice: Optional[StreamSlice] = None) -> Optional[StreamState]:
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,6 @@ def should_be_synced(self, record: Record) -> bool:
"""
return True

def is_greater_than_or_equal(self, first: Record, second: Record) -> bool:
    """
    Always answer False: RFR records have no ordering by which they could be compared with one another.
    """
    return False

def select_state(self, stream_slice: Optional[StreamSlice] = None) -> Optional[StreamState]:
    """Return the cursor held by this instance; `stream_slice` is not consulted."""
    # A top-level RFR cursor only manages the state of a single partition
    return self._cursor
Original file line number Diff line number Diff line change
Expand Up @@ -89,12 +89,6 @@ def should_be_synced(self, record: Record) -> bool:
"""
return True

def is_greater_than_or_equal(self, first: Record, second: Record) -> bool:
    """
    Always answer False: RFR records have no ordering by which they could be compared with one another.
    """
    return False

def select_state(self, stream_slice: Optional[StreamSlice] = None) -> Optional[StreamState]:
if not stream_slice:
raise ValueError("A partition needs to be provided in order to extract a state")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1205,57 +1205,5 @@ def test_given_record_without_cursor_value_when_should_be_synced_then_return_tru
assert cursor.should_be_synced(Record({"record without cursor value": "any"}, ANY_SLICE))


def test_given_first_greater_than_second_then_return_true():
    """A first record with the later cursor date compares as greater than or equal."""
    datetime_cursor = DatetimeBasedCursor(
        start_datetime=MinMaxDatetime("3000-01-01", parameters={}),
        cursor_field="cursor_field",
        datetime_format="%Y-%m-%d",
        config=config,
        parameters={},
    )
    later_record = Record({"cursor_field": "2023-01-01"}, {})
    earlier_record = Record({"cursor_field": "2021-01-01"}, {})

    assert datetime_cursor.is_greater_than_or_equal(later_record, earlier_record)


def test_given_first_lesser_than_second_then_return_false():
    """A first record with the earlier cursor date does not compare as greater than or equal."""
    datetime_cursor = DatetimeBasedCursor(
        start_datetime=MinMaxDatetime("3000-01-01", parameters={}),
        cursor_field="cursor_field",
        datetime_format="%Y-%m-%d",
        config=config,
        parameters={},
    )
    earlier_record = Record({"cursor_field": "2021-01-01"}, {})
    later_record = Record({"cursor_field": "2023-01-01"}, {})

    assert not datetime_cursor.is_greater_than_or_equal(earlier_record, later_record)


def test_given_no_cursor_value_for_second_than_second_then_return_true():
    """When only the first record has a cursor value, it compares as greater than or equal."""
    datetime_cursor = DatetimeBasedCursor(
        start_datetime=MinMaxDatetime("3000-01-01", parameters={}),
        cursor_field="cursor_field",
        datetime_format="%Y-%m-%d",
        config=config,
        parameters={},
    )
    record_with_cursor = Record({"cursor_field": "2021-01-01"}, {})
    record_without_cursor = Record({}, {})

    assert datetime_cursor.is_greater_than_or_equal(record_with_cursor, record_without_cursor)


def test_given_no_cursor_value_for_first_than_second_then_return_false():
    """When the first record has no cursor value, it does not compare as greater than or equal."""
    datetime_cursor = DatetimeBasedCursor(
        start_datetime=MinMaxDatetime("3000-01-01", parameters={}),
        cursor_field="cursor_field",
        datetime_format="%Y-%m-%d",
        config=config,
        parameters={},
    )
    record_without_cursor = Record({}, {})
    record_with_cursor = Record({"cursor_field": "2021-01-01"}, {})

    assert not datetime_cursor.is_greater_than_or_equal(record_without_cursor, record_with_cursor)


# Run this module's tests through unittest when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()
Original file line number Diff line number Diff line change
Expand Up @@ -344,81 +344,6 @@ def test_given_unknown_partition_when_should_be_synced_then_raise_error():
)


def test_given_records_with_different_slice_when_is_greater_than_or_equal_then_raise_error():
    """Comparing records whose slices have different partitions raises ValueError."""
    per_partition_cursor = PerPartitionCursor(Mock(), Mock())

    with pytest.raises(ValueError):
        record_in_one_partition = Record(
            {}, StreamSlice(partition={"a slice": "value"}, cursor_slice={})
        )
        record_in_other_partition = Record(
            {}, StreamSlice(partition={"another slice": "value"}, cursor_slice={})
        )
        per_partition_cursor.is_greater_than_or_equal(
            record_in_one_partition, record_in_other_partition
        )


@pytest.mark.parametrize(
    "first_record_slice, second_record_slice",
    [
        pytest.param(
            StreamSlice(partition={"a slice": "value"}, cursor_slice={}),
            None,
            id="second record does not have a slice",
        ),
        pytest.param(
            None,
            StreamSlice(partition={"a slice": "value"}, cursor_slice={}),
            id="first record does not have a slice",
        ),
    ],
)
def test_given_records_without_a_slice_when_is_greater_than_or_equal_then_raise_error(
    first_record_slice, second_record_slice
):
    """Comparing records raises ValueError when either one lacks an associated slice."""
    per_partition_cursor = PerPartitionCursor(Mock(), Mock())

    with pytest.raises(ValueError):
        left_record = Record({}, first_record_slice)
        right_record = Record({}, second_record_slice)
        per_partition_cursor.is_greater_than_or_equal(left_record, right_record)


def test_given_slice_is_unknown_when_is_greater_than_or_equal_then_raise_error():
    """Comparing records on a cursor that never generated any slices raises ValueError."""
    per_partition_cursor = PerPartitionCursor(Mock(), Mock())

    with pytest.raises(ValueError):
        left_record = Record({}, StreamSlice(partition={"a slice": "value"}, cursor_slice={}))
        right_record = Record({}, StreamSlice(partition={"a slice": "value"}, cursor_slice={}))
        per_partition_cursor.is_greater_than_or_equal(left_record, right_record)


def test_when_is_greater_than_or_equal_then_return_underlying_cursor_response(
    mocked_cursor_factory, mocked_partition_router
):
    """The per-partition cursor returns whatever the partition's underlying cursor answers."""
    cursor_builder = MockedCursorBuilder().with_stream_slices(
        [{CURSOR_SLICE_FIELD: "first slice cursor value"}]
    )
    underlying_cursor = cursor_builder.build()
    mocked_cursor_factory.create.side_effect = [underlying_cursor]

    partition_slice = StreamSlice(partition={"partition key": "first partition"}, cursor_slice={})
    mocked_partition_router.stream_slices.return_value = [partition_slice]

    per_partition_cursor = PerPartitionCursor(mocked_cursor_factory, mocked_partition_router)
    first_record = Record(
        data={"first": "value"}, associated_slice=partition_slice, stream_name="test_stream"
    )
    second_record = Record(
        data={"second": "value"}, associated_slice=partition_slice, stream_name="test_stream"
    )
    # Consuming the slices is what makes the cursor create its per-partition state.
    list(per_partition_cursor.stream_slices())

    outcome = per_partition_cursor.is_greater_than_or_equal(first_record, second_record)

    assert outcome == underlying_cursor.is_greater_than_or_equal.return_value
    underlying_cursor.is_greater_than_or_equal.assert_called_once_with(first_record, second_record)


@pytest.mark.parametrize(
"stream_slice, expected_output",
[
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -722,56 +722,6 @@ def test_limit_stream_slices():
assert truncated_slices == _generate_slices(maximum_number_of_slices)


@pytest.mark.parametrize(
    "test_name, first_greater_than_second",
    [
        ("test_first_greater_than_second", True),
        ("test_second_greater_than_first", False),
    ],
)
def test_when_read_records_then_cursor_close_slice_with_greater_record(
    test_name, first_greater_than_second
):
    """After reading a slice, `close_slice` receives the record the cursor ranks as greater."""
    first_record = Record({"first": 1}, StreamSlice(cursor_slice={}, partition={}))
    second_record = Record({"second": 2}, StreamSlice(cursor_slice={}, partition={}))

    selector = MagicMock()
    selector.select_records.return_value = [first_record, second_record]

    cursor = MagicMock(spec=DeclarativeCursor)
    cursor.is_greater_than_or_equal.return_value = first_greater_than_second

    paginator = MagicMock()
    paginator.get_request_headers.return_value = {}

    retriever = SimpleRetriever(
        name="stream_name",
        primary_key=primary_key,
        requester=MagicMock(),
        paginator=paginator,
        record_selector=selector,
        stream_slicer=cursor,
        cursor=cursor,
        parameters={},
        config={},
    )
    stream_slice = StreamSlice(cursor_slice={}, partition={"repository": "airbyte"})

    def parse_instead_of_reading_pages(_, __, ___):
        # Stand-in for `_read_pages`: go straight to record parsing with a mocked response.
        return retriever._parse_records(
            response=MagicMock(), stream_state={}, stream_slice=stream_slice, records_schema={}
        )

    expected_record = first_record if first_greater_than_second else second_record

    with patch.object(
        SimpleRetriever,
        "_read_pages",
        return_value=iter([first_record, second_record]),
        side_effect=parse_instead_of_reading_pages,
    ):
        list(retriever.read_records(stream_slice=stream_slice, records_schema={}))
        cursor.close_slice.assert_called_once_with(stream_slice, expected_record)


def test_given_stream_data_is_not_record_when_read_records_then_update_slice_with_optional_record():
stream_data = [
AirbyteMessage(
Expand Down Expand Up @@ -808,7 +758,7 @@ def retriever_read_pages(_, __, ___):
):
list(retriever.read_records(stream_slice=stream_slice, records_schema={}))
cursor.observe.assert_not_called()
cursor.close_slice.assert_called_once_with(stream_slice, None)
cursor.close_slice.assert_called_once_with(stream_slice)


def test_given_initial_token_is_zero_when_read_records_then_pass_initial_token():
Expand Down
Loading