Skip to content

Commit

Permalink
Emit state when no partitions are generated for ccdk (#34605)
Browse files Browse the repository at this point in the history
  • Loading branch information
maxi297 committed Jan 30, 2024
1 parent 28dae9a commit 2c8b47b
Show file tree
Hide file tree
Showing 12 changed files with 78 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -66,14 +66,12 @@ def on_partition_generation_completed(self, sentinel: PartitionGenerationComplet
"""
stream_name = sentinel.stream.name
self._streams_currently_generating_partitions.remove(sentinel.stream.name)
ret = []
# It is possible for the stream to already be done if no partitions were generated
# If the partition generation process was completed and there are no partitions left to process, the stream is done
if self._is_stream_done(stream_name) or len(self._streams_to_running_partitions[stream_name]) == 0:
ret.append(self._on_stream_is_done(stream_name))
yield from self._on_stream_is_done(stream_name)
if self._stream_instances_to_start_partition_generation:
ret.append(self.start_next_partition_generator())
return ret
yield self.start_next_partition_generator()

def on_partition(self, partition: Partition) -> None:
"""
Expand Down Expand Up @@ -102,7 +100,7 @@ def on_partition_complete_sentinel(self, sentinel: PartitionCompleteSentinel) ->
partitions_running.remove(partition)
# If all partitions were generated and this was the last one, the stream is done
if partition.stream_name() not in self._streams_currently_generating_partitions and len(partitions_running) == 0:
yield self._on_stream_is_done(partition.stream_name())
yield from self._on_stream_is_done(partition.stream_name())
yield from self._message_repository.consume_queue()

def on_record(self, record: Record) -> Iterable[AirbyteMessage]:
Expand Down Expand Up @@ -171,13 +169,15 @@ def is_done(self) -> bool:
def _is_stream_done(self, stream_name: str) -> bool:
# A stream counts as done once its name is present in `self._streams_done`.
return stream_name in self._streams_done

def _on_stream_is_done(self, stream_name: str) -> AirbyteMessage:
def _on_stream_is_done(self, stream_name: str) -> Iterable[AirbyteMessage]:
self._logger.info(f"Read {self._record_counter[stream_name]} records from {stream_name} stream")
self._logger.info(f"Marking stream {stream_name} as STOPPED")
stream = self._stream_name_to_instance[stream_name]
stream.cursor.ensure_at_least_one_state_emitted()
yield from self._message_repository.consume_queue()
self._logger.info(f"Finished syncing {stream.name}")
self._streams_done.add(stream_name)
return stream_status_as_airbyte_message(stream.as_airbyte_stream(), AirbyteStreamStatus.COMPLETE)
yield stream_status_as_airbyte_message(stream.as_airbyte_stream(), AirbyteStreamStatus.COMPLETE)

def _stop_streams(self) -> Iterable[AirbyteMessage]:
self._thread_pool_manager.shutdown()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from airbyte_cdk.models import AirbyteStream
from airbyte_cdk.sources.streams.concurrent.availability_strategy import StreamAvailability
from airbyte_cdk.sources.streams.concurrent.cursor import Cursor
from airbyte_cdk.sources.streams.concurrent.partitions.partition import Partition
from deprecated.classic import deprecated

Expand Down Expand Up @@ -81,3 +82,10 @@ def log_stream_sync_configuration(self) -> None:
"""
Logs the stream's configuration for debugging purposes.
"""

@property
@abstractmethod
def cursor(self) -> Cursor:
"""
:return: The cursor associated with this stream. Abstract: concrete streams must implement this property.
"""
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ def create_from_stream(
primary_key=pk,
cursor_field=cursor_field,
logger=logger,
cursor=cursor,
),
stream,
cursor,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,14 @@ def close_partition(self, partition: Partition) -> None:
"""
raise NotImplementedError()

@abstractmethod
def ensure_at_least_one_state_emitted(self) -> None:
"""
State messages are emitted when a partition is closed. However, the platform expects at least one state to be emitted per sync per
stream. Hence, if no partitions are generated, this method needs to be called.
"""
# Abstract: concrete cursors are expected to override this; the base body only raises.
raise NotImplementedError()


class NoopCursor(Cursor):
@property
Expand All @@ -68,6 +76,9 @@ def observe(self, record: Record) -> None:
def close_partition(self, partition: Partition) -> None:
pass

def ensure_at_least_one_state_emitted(self) -> None:
# Deliberate no-op: this cursor implementation emits no state message here.
pass


class ConcurrentCursor(Cursor):
_START_BOUNDARY = 0
Expand Down Expand Up @@ -179,3 +190,10 @@ def _extract_from_slice(self, partition: Partition, key: str) -> Comparable:
return self._connector_state_converter.parse_value(_slice[key]) # type: ignore # we expect the devs to specify a key that would return a Comparable
except KeyError as exception:
raise KeyError(f"Partition is expected to have key `{key}` but could not be found") from exception

def ensure_at_least_one_state_emitted(self) -> None:
"""
The platform expects at least one state message on successful syncs. Hence, whatever happens, we expect this method to be
called. It delegates to `self._emit_state_message()`.
"""
self._emit_state_message()
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from airbyte_cdk.models import AirbyteStream, SyncMode
from airbyte_cdk.sources.streams.concurrent.abstract_stream import AbstractStream
from airbyte_cdk.sources.streams.concurrent.availability_strategy import AbstractAvailabilityStrategy, StreamAvailability
from airbyte_cdk.sources.streams.concurrent.cursor import Cursor, NoopCursor
from airbyte_cdk.sources.streams.concurrent.partitions.partition import Partition
from airbyte_cdk.sources.streams.concurrent.partitions.partition_generator import PartitionGenerator

Expand All @@ -23,6 +24,7 @@ def __init__(
primary_key: List[str],
cursor_field: Optional[str],
logger: Logger,
cursor: Optional[Cursor],
namespace: Optional[str] = None,
) -> None:
self._stream_partition_generator = partition_generator
Expand All @@ -32,6 +34,7 @@ def __init__(
self._primary_key = primary_key
self._cursor_field = cursor_field
self._logger = logger
self._cursor = cursor or NoopCursor()
self._namespace = namespace

def generate_partitions(self) -> Iterable[Partition]:
Expand Down Expand Up @@ -77,3 +80,7 @@ def log_stream_sync_configuration(self) -> None:
"cursor_field": self.cursor_field,
},
)

@property
def cursor(self) -> Cursor:
# Expose the cursor stored on the instance as `self._cursor`.
return self._cursor
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@
{"data": {"id": "3", "cursor_field": 2}, "stream": "stream1"},
{"data": {"id": "4", "cursor_field": 3}, "stream": "stream1"},
{"stream1": {"cursor_field": 2}},
{"stream1": {"cursor_field": 2}}, # see Cursor.ensure_at_least_one_state_emitted
]
)
.set_log_levels({"ERROR", "WARN", "WARNING", "INFO", "DEBUG"})
Expand Down Expand Up @@ -152,6 +153,7 @@
{"data": {"id": "3", "cursor_field": 2}, "stream": "stream1"},
{"data": {"id": "4", "cursor_field": 3}, "stream": "stream1"},
{"stream1": {"cursor_field": 2}},
{"stream1": {"cursor_field": 2}}, # see Cursor.ensure_at_least_one_state_emitted
]
)
.set_log_levels({"ERROR", "WARN", "WARNING", "INFO", "DEBUG"})
Expand Down Expand Up @@ -239,6 +241,7 @@
{"data": {"id": "3", "cursor_field": 2}, "stream": "stream1"},
{"data": {"id": "4", "cursor_field": 3}, "stream": "stream1"},
{"stream1": {"cursor_field": 2}},
{"stream1": {"cursor_field": 2}}, # see Cursor.ensure_at_least_one_state_emitted
]
)
.set_log_levels({"ERROR", "WARN", "WARNING", "INFO", "DEBUG"})
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,7 @@
{"data": {"id": "3", "cursor_field": 2}, "stream": "stream1"},
{"data": {"id": "4", "cursor_field": 3}, "stream": "stream1"},
{"stream1": {"cursor_field": 2}},
{"stream1": {"cursor_field": 2}}, # see Cursor.ensure_at_least_one_state_emitted
]
)
.set_log_levels({"ERROR", "WARN", "WARNING", "INFO", "DEBUG"})
Expand Down Expand Up @@ -403,6 +404,7 @@
{"data": {"id": "1", "cursor_field": 0}, "stream": "stream1"},
{"data": {"id": "2", "cursor_field": 3}, "stream": "stream1"},
{"stream1": {"cursor_field": 3}},
{"stream1": {"cursor_field": 3}}, # see Cursor.ensure_at_least_one_state_emitted
]
)
.set_log_levels({"ERROR", "WARN", "WARNING", "INFO", "DEBUG"})
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import logging

from airbyte_cdk.sources.message import InMemoryMessageRepository
from airbyte_cdk.sources.streams.concurrent.cursor import NoopCursor
from airbyte_cdk.sources.streams.concurrent.default_stream import DefaultStream
from airbyte_cdk.sources.streams.concurrent.partitions.record import Record
from unit_tests.sources.file_based.scenarios.scenario_builder import TestScenarioBuilder
Expand All @@ -29,6 +30,7 @@
primary_key=[],
cursor_field=None,
logger=logging.getLogger("test_logger"),
cursor=NoopCursor(),
)

_id_only_stream_with_slice_logger = DefaultStream(
Expand All @@ -46,6 +48,7 @@
primary_key=[],
cursor_field=None,
logger=logging.getLogger("test_logger"),
cursor=NoopCursor(),
)

_id_only_stream_with_primary_key = DefaultStream(
Expand All @@ -63,6 +66,7 @@
primary_key=["id"],
cursor_field=None,
logger=logging.getLogger("test_logger"),
cursor=NoopCursor(),
)

_id_only_stream_multiple_partitions = DefaultStream(
Expand All @@ -83,6 +87,7 @@
primary_key=[],
cursor_field=None,
logger=logging.getLogger("test_logger"),
cursor=NoopCursor(),
)

_id_only_stream_multiple_partitions_concurrency_level_two = DefaultStream(
Expand All @@ -103,6 +108,7 @@
primary_key=[],
cursor_field=None,
logger=logging.getLogger("test_logger"),
cursor=NoopCursor(),
)

_stream_raising_exception = DefaultStream(
Expand All @@ -120,6 +126,7 @@
primary_key=[],
cursor_field=None,
logger=logging.getLogger("test_logger"),
cursor=NoopCursor(),
)

test_concurrent_cdk_single_stream = (
Expand Down Expand Up @@ -246,6 +253,7 @@
primary_key=[],
cursor_field=None,
logger=logging.getLogger("test_logger"),
cursor=NoopCursor(),
),
]
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
#
import logging
import unittest
from unittest.mock import Mock
from unittest.mock import Mock, call

import freezegun
from airbyte_cdk.models import (
Expand Down Expand Up @@ -32,6 +32,7 @@

_STREAM_NAME = "stream"
_ANOTHER_STREAM_NAME = "stream2"
_ANY_AIRBYTE_MESSAGE = Mock(spec=AirbyteMessage)


class TestConcurrentReadProcessor(unittest.TestCase):
Expand Down Expand Up @@ -110,6 +111,10 @@ def test_handle_partition_done_no_other_streams_to_generate_partitions_for(self)

@freezegun.freeze_time("2020-01-01T00:00:00")
def test_handle_last_stream_partition_done(self):
in_order_validation_mock = Mock()
in_order_validation_mock.attach_mock(self._another_stream, "_another_stream")
in_order_validation_mock.attach_mock(self._message_repository, '_message_repository')
self._message_repository.consume_queue.return_value = iter([_ANY_AIRBYTE_MESSAGE])
stream_instances_to_read_from = [self._another_stream]

handler = ConcurrentReadProcessor(
Expand All @@ -124,9 +129,10 @@ def test_handle_last_stream_partition_done(self):
handler.start_next_partition_generator()

sentinel = PartitionGenerationCompletedSentinel(self._another_stream)
messages = handler.on_partition_generation_completed(sentinel)
messages = list(handler.on_partition_generation_completed(sentinel))

expected_messages = [
_ANY_AIRBYTE_MESSAGE,
AirbyteMessage(
type=MessageType.TRACE,
trace=AirbyteTraceMessage(
Expand All @@ -140,6 +146,7 @@ def test_handle_last_stream_partition_done(self):
)
]
assert expected_messages == messages
assert in_order_validation_mock.mock_calls.index(call._another_stream.cursor.ensure_at_least_one_state_emitted) < in_order_validation_mock.mock_calls.index(call._message_repository.consume_queue)

def test_handle_partition(self):
stream_instances_to_read_from = [self._another_stream]
Expand Down Expand Up @@ -236,7 +243,7 @@ def test_handle_on_partition_complete_sentinel_yields_status_message_if_the_stre
)
handler.start_next_partition_generator()
handler.on_partition(self._a_closed_partition)
handler.on_partition_generation_completed(PartitionGenerationCompletedSentinel(self._another_stream))
list(handler.on_partition_generation_completed(PartitionGenerationCompletedSentinel(self._another_stream)))

sentinel = PartitionCompleteSentinel(self._a_closed_partition)

Expand Down Expand Up @@ -543,8 +550,8 @@ def test_on_exception_does_not_stop_streams_that_are_already_done(self):

handler.start_next_partition_generator()
handler.on_partition(self._an_open_partition)
handler.on_partition_generation_completed(PartitionGenerationCompletedSentinel(self._stream))
handler.on_partition_generation_completed(PartitionGenerationCompletedSentinel(self._another_stream))
list(handler.on_partition_generation_completed(PartitionGenerationCompletedSentinel(self._stream)))
list(handler.on_partition_generation_completed(PartitionGenerationCompletedSentinel(self._another_stream)))

another_stream = Mock(spec=AbstractStream)
another_stream.name = _STREAM_NAME
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from airbyte_cdk.models import AirbyteStream, SyncMode
from airbyte_cdk.sources.streams.concurrent.availability_strategy import STREAM_AVAILABLE
from airbyte_cdk.sources.streams.concurrent.cursor import Cursor
from airbyte_cdk.sources.streams.concurrent.cursor import Cursor, NoopCursor
from airbyte_cdk.sources.streams.concurrent.default_stream import DefaultStream


Expand All @@ -28,6 +28,7 @@ def setUp(self):
self._primary_key,
self._cursor_field,
self._logger,
NoopCursor(),
)

def test_get_json_schema(self):
Expand Down Expand Up @@ -88,6 +89,7 @@ def test_as_airbyte_stream_with_primary_key(self):
["id"],
self._cursor_field,
self._logger,
NoopCursor(),
)

expected_airbyte_stream = AirbyteStream(
Expand Down Expand Up @@ -119,6 +121,7 @@ def test_as_airbyte_stream_with_composite_primary_key(self):
["id_a", "id_b"],
self._cursor_field,
self._logger,
NoopCursor(),
)

expected_airbyte_stream = AirbyteStream(
Expand Down Expand Up @@ -150,6 +153,7 @@ def test_as_airbyte_stream_with_a_cursor(self):
self._primary_key,
"date",
self._logger,
NoopCursor(),
)

expected_airbyte_stream = AirbyteStream(
Expand All @@ -174,6 +178,7 @@ def test_as_airbyte_stream_with_namespace(self):
self._primary_key,
self._cursor_field,
self._logger,
NoopCursor(),
namespace="test",
)
expected_airbyte_stream = AirbyteStream(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def test_given_no_partitions_when_generate_partitions_then_do_not_wait(self, moc

assert mocked_sleep.call_count == 0

def test_given_partitions_when_generate_partitions_then_only_push_sentinel(self):
def test_given_no_partitions_when_generate_partitions_then_only_push_sentinel(self):
self._thread_pool_manager.prune_to_validate_has_reached_futures_limit.return_value = True
stream = self._a_stream([])

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from airbyte_cdk.sources.message import InMemoryMessageRepository, MessageRepository
from airbyte_cdk.sources.streams.concurrent.abstract_stream import AbstractStream
from airbyte_cdk.sources.streams.concurrent.availability_strategy import StreamAvailability, StreamAvailable, StreamUnavailable
from airbyte_cdk.sources.streams.concurrent.cursor import Cursor, NoopCursor
from airbyte_cdk.sources.streams.concurrent.partitions.partition import Partition
from airbyte_cdk.sources.streams.concurrent.partitions.record import Record
from airbyte_protocol.models import AirbyteStream
Expand Down Expand Up @@ -72,6 +73,10 @@ def as_airbyte_stream(self) -> AirbyteStream:
def log_stream_sync_configuration(self) -> None:
raise NotImplementedError

@property
def cursor(self) -> Cursor:
# Returns a new NoopCursor instance on every access.
return NoopCursor()


class _MockPartition(Partition):
def __init__(self, name: str):
Expand Down

0 comments on commit 2c8b47b

Please sign in to comment.