Skip to content

Commit

Permalink
Connector builder: read input state if it exists (#37495)
Browse files Browse the repository at this point in the history
  • Loading branch information
girarda committed Apr 24, 2024
1 parent 28209fd commit 86ee91e
Show file tree
Hide file tree
Showing 6 changed files with 71 additions and 46 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@

import dataclasses
from datetime import datetime
from typing import Any, Mapping
from typing import Any, List, Mapping

from airbyte_cdk.connector_builder.message_grouper import MessageGrouper
from airbyte_cdk.models import AirbyteMessage, AirbyteRecordMessage, ConfiguredAirbyteCatalog
from airbyte_cdk.models import AirbyteMessage, AirbyteRecordMessage, AirbyteStateMessage, ConfiguredAirbyteCatalog
from airbyte_cdk.models import Type
from airbyte_cdk.models import Type as MessageType
from airbyte_cdk.sources.declarative.declarative_source import DeclarativeSource
Expand Down Expand Up @@ -54,12 +54,12 @@ def create_source(config: Mapping[str, Any], limits: TestReadLimits) -> Manifest


def read_stream(
source: DeclarativeSource, config: Mapping[str, Any], configured_catalog: ConfiguredAirbyteCatalog, limits: TestReadLimits
source: DeclarativeSource, config: Mapping[str, Any], configured_catalog: ConfiguredAirbyteCatalog, state: List[AirbyteStateMessage], limits: TestReadLimits
) -> AirbyteMessage:
try:
handler = MessageGrouper(limits.max_pages_per_slice, limits.max_slices, limits.max_records)
stream_name = configured_catalog.streams[0].stream.name # The connector builder only supports a single stream
stream_read = handler.get_message_groups(source, config, configured_catalog, limits.max_records)
stream_read = handler.get_message_groups(source, config, configured_catalog, state, limits.max_records)
return AirbyteMessage(
type=MessageType.RECORD,
record=AirbyteRecordMessage(data=dataclasses.asdict(stream_read), stream=stream_name, emitted_at=_emitted_at()),
Expand Down
18 changes: 11 additions & 7 deletions airbyte-cdk/python/airbyte_cdk/connector_builder/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,17 @@
from airbyte_cdk.connector import BaseConnector
from airbyte_cdk.connector_builder.connector_builder_handler import TestReadLimits, create_source, get_limits, read_stream, resolve_manifest
from airbyte_cdk.entrypoint import AirbyteEntrypoint
from airbyte_cdk.models import AirbyteMessage, ConfiguredAirbyteCatalog
from airbyte_cdk.models import AirbyteMessage, AirbyteStateMessage, ConfiguredAirbyteCatalog
from airbyte_cdk.sources.declarative.manifest_declarative_source import ManifestDeclarativeSource
from airbyte_cdk.sources.source import Source
from airbyte_cdk.utils.traced_exception import AirbyteTracedException


def get_config_and_catalog_from_args(args: List[str]) -> Tuple[str, Mapping[str, Any], Optional[ConfiguredAirbyteCatalog]]:
def get_config_and_catalog_from_args(args: List[str]) -> Tuple[str, Mapping[str, Any], Optional[ConfiguredAirbyteCatalog], Any]:
# TODO: Add functionality for the `debug` logger.
# Currently, no one `debug` level log will be displayed during `read` a stream for a connector created through `connector-builder`.
parsed_args = AirbyteEntrypoint.parse_args(args)
config_path, catalog_path = parsed_args.config, parsed_args.catalog
config_path, catalog_path, state_path = parsed_args.config, parsed_args.catalog, parsed_args.state
if parsed_args.command != "read":
raise ValueError("Only read commands are allowed for Connector Builder requests.")

Expand All @@ -32,38 +33,41 @@ def get_config_and_catalog_from_args(args: List[str]) -> Tuple[str, Mapping[str,
command = config["__command"]
if command == "test_read":
catalog = ConfiguredAirbyteCatalog.parse_obj(BaseConnector.read_config(catalog_path))
state = Source.read_state(state_path)
else:
catalog = None
state = []

if "__injected_declarative_manifest" not in config:
raise ValueError(
f"Invalid config: `__injected_declarative_manifest` should be provided at the root of the config but config only has keys {list(config.keys())}"
)

return command, config, catalog
return command, config, catalog, state


def handle_connector_builder_request(
source: ManifestDeclarativeSource,
command: str,
config: Mapping[str, Any],
catalog: Optional[ConfiguredAirbyteCatalog],
state: List[AirbyteStateMessage],
limits: TestReadLimits,
) -> AirbyteMessage:
if command == "resolve_manifest":
return resolve_manifest(source)
elif command == "test_read":
assert catalog is not None, "`test_read` requires a valid `ConfiguredAirbyteCatalog`, got None."
return read_stream(source, config, catalog, limits)
return read_stream(source, config, catalog, state, limits)
else:
raise ValueError(f"Unrecognized command {command}.")


def handle_request(args: List[str]) -> AirbyteMessage:
command, config, catalog = get_config_and_catalog_from_args(args)
command, config, catalog, state = get_config_and_catalog_from_args(args)
limits = get_limits(config)
source = create_source(config, limits)
return handle_connector_builder_request(source, command, config, catalog, limits).json(exclude_unset=True)
return handle_connector_builder_request(source, command, config, catalog, state, limits).json(exclude_unset=True)


if __name__ == "__main__":
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
AirbyteControlMessage,
AirbyteLogMessage,
AirbyteMessage,
AirbyteStateMessage,
AirbyteTraceMessage,
ConfiguredAirbyteCatalog,
OrchestratorType,
Expand Down Expand Up @@ -75,6 +76,7 @@ def get_message_groups(
source: DeclarativeSource,
config: Mapping[str, Any],
configured_catalog: ConfiguredAirbyteCatalog,
state: List[AirbyteStateMessage],
record_limit: Optional[int] = None,
) -> StreamRead:
if record_limit is not None and not (1 <= record_limit <= self._max_record_limit):
Expand All @@ -96,7 +98,7 @@ def get_message_groups(
latest_config_update: AirbyteControlMessage = None
auxiliary_requests = []
for message_group in self._get_message_groups(
self._read_stream(source, config, configured_catalog),
self._read_stream(source, config, configured_catalog, state),
schema_inferrer,
datetime_format_inferrer,
record_limit,
Expand Down Expand Up @@ -181,7 +183,7 @@ def _get_message_groups(
and message.type == MessageType.LOG
and message.log.message.startswith(SliceLogger.SLICE_LOG_PREFIX)
):
yield StreamReadSlices(pages=current_slice_pages, slice_descriptor=current_slice_descriptor, state=latest_state_message)
yield StreamReadSlices(pages=current_slice_pages, slice_descriptor=current_slice_descriptor, state=[latest_state_message] if latest_state_message else [])
current_slice_descriptor = self._parse_slice_description(message.log.message)
current_slice_pages = []
at_least_one_page_in_group = False
Expand Down Expand Up @@ -228,7 +230,7 @@ def _get_message_groups(
else:
if current_page_request or current_page_response or current_page_records:
self._close_page(current_page_request, current_page_response, current_slice_pages, current_page_records)
yield StreamReadSlices(pages=current_slice_pages, slice_descriptor=current_slice_descriptor, state=latest_state_message)
yield StreamReadSlices(pages=current_slice_pages, slice_descriptor=current_slice_descriptor, state=[latest_state_message] if latest_state_message else [])

@staticmethod
def _need_to_close_page(at_least_one_page_in_group: bool, message: AirbyteMessage, json_message: Optional[Dict[str, Any]]) -> bool:
Expand Down Expand Up @@ -279,12 +281,13 @@ def _close_page(
current_page_records.clear()

def _read_stream(
self, source: DeclarativeSource, config: Mapping[str, Any], configured_catalog: ConfiguredAirbyteCatalog
self, source: DeclarativeSource, config: Mapping[str, Any], configured_catalog: ConfiguredAirbyteCatalog,
state: List[AirbyteStateMessage]
) -> Iterator[AirbyteMessage]:
# the generator can raise an exception
# iterate over the generated messages. if next raise an exception, catch it and yield it as an AirbyteLogMessage
try:
yield from AirbyteEntrypoint(source).read(source.spec(self.logger), config, configured_catalog, {})
yield from AirbyteEntrypoint(source).read(source.spec(self.logger), config, configured_catalog, state)
except Exception as e:
error_message = f"{e.args[0] if len(e.args) > 0 else str(e)}"
yield AirbyteTracedException.from_exception(e, message=error_message).as_airbyte_message()
Expand Down
2 changes: 1 addition & 1 deletion airbyte-cdk/python/airbyte_cdk/connector_builder/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ class StreamReadPages:
class StreamReadSlices:
pages: List[StreamReadPages]
slice_descriptor: Optional[Dict[str, Any]]
state: Optional[Dict[str, Any]] = None
state: Optional[List[Dict[str, Any]]] = None


@dataclass
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,15 @@
AirbyteLogMessage,
AirbyteMessage,
AirbyteRecordMessage,
AirbyteStateMessage,
AirbyteStream,
AirbyteStreamState,
ConfiguredAirbyteCatalog,
ConfiguredAirbyteStream,
ConnectorSpecification,
DestinationSyncMode,
Level,
StreamDescriptor,
SyncMode,
)
from airbyte_cdk.models import Type
Expand All @@ -50,6 +53,18 @@
_stream_options = {"name": _stream_name, "primary_key": _stream_primary_key, "url_base": _stream_url_base}
_page_size = 2

_A_STATE = [AirbyteStateMessage(
type="STREAM",
stream=AirbyteStreamState(
stream_descriptor=StreamDescriptor(
name=_stream_name
),
stream_state={
"key": "value"
}
)
)]

MANIFEST = {
"version": "0.30.3",
"definitions": {
Expand Down Expand Up @@ -266,7 +281,7 @@ def test_resolve_manifest(valid_resolve_manifest_config_file):
config["__command"] = command
source = ManifestDeclarativeSource(MANIFEST)
limits = TestReadLimits()
resolved_manifest = handle_connector_builder_request(source, command, config, create_configured_catalog("dummy_stream"), limits)
resolved_manifest = handle_connector_builder_request(source, command, config, create_configured_catalog("dummy_stream"), _A_STATE, limits)

expected_resolved_manifest = {
"type": "DeclarativeSource",
Expand Down Expand Up @@ -455,10 +470,11 @@ def test_read():
),
)
limits = TestReadLimits()
with patch("airbyte_cdk.connector_builder.message_grouper.MessageGrouper.get_message_groups", return_value=stream_read):
with patch("airbyte_cdk.connector_builder.message_grouper.MessageGrouper.get_message_groups", return_value=stream_read) as mock:
output_record = handle_connector_builder_request(
source, "test_read", config, ConfiguredAirbyteCatalog.parse_obj(CONFIGURED_CATALOG), limits
source, "test_read", config, ConfiguredAirbyteCatalog.parse_obj(CONFIGURED_CATALOG), _A_STATE, limits
)
mock.assert_called_with(source, config, ConfiguredAirbyteCatalog.parse_obj(CONFIGURED_CATALOG), _A_STATE, limits.max_records)
output_record.record.emitted_at = 1
assert output_record == expected_airbyte_message

Expand Down Expand Up @@ -492,7 +508,7 @@ def test_config_update():
return_value=refresh_request_response,
):
output = handle_connector_builder_request(
source, "test_read", config, ConfiguredAirbyteCatalog.parse_obj(CONFIGURED_CATALOG), TestReadLimits()
source, "test_read", config, ConfiguredAirbyteCatalog.parse_obj(CONFIGURED_CATALOG), _A_STATE, TestReadLimits()
)
assert output.record.data["latest_config_update"]

Expand Down Expand Up @@ -529,7 +545,7 @@ def check_config_against_spec(self):

source = MockManifestDeclarativeSource()
limits = TestReadLimits()
response = read_stream(source, TEST_READ_CONFIG, ConfiguredAirbyteCatalog.parse_obj(CONFIGURED_CATALOG), limits)
response = read_stream(source, TEST_READ_CONFIG, ConfiguredAirbyteCatalog.parse_obj(CONFIGURED_CATALOG), _A_STATE, limits)

expected_stream_read = StreamRead(
logs=[LogMessage("error_message - a stack trace", "ERROR")],
Expand Down Expand Up @@ -716,7 +732,7 @@ def test_read_source(mock_http_stream):

source = create_source(config, limits)

output_data = read_stream(source, config, catalog, limits).record.data
output_data = read_stream(source, config, catalog, _A_STATE, limits).record.data
slices = output_data["slices"]

assert len(slices) == max_slices
Expand Down Expand Up @@ -761,7 +777,7 @@ def test_read_source_single_page_single_slice(mock_http_stream):

source = create_source(config, limits)

output_data = read_stream(source, config, catalog, limits).record.data
output_data = read_stream(source, config, catalog, _A_STATE, limits).record.data
slices = output_data["slices"]

assert len(slices) == max_slices
Expand Down Expand Up @@ -817,7 +833,7 @@ def test_handle_read_external_requests(deployment_mode, url_base, expected_error
source = create_source(config, limits)

with mock.patch.dict(os.environ, {"DEPLOYMENT_MODE": deployment_mode}, clear=False):
output_data = read_stream(source, config, catalog, limits).record.data
output_data = read_stream(source, config, catalog, _A_STATE, limits).record.data
if expected_error:
assert len(output_data["logs"]) > 0, "Expected at least one log message with the expected error"
error_message = output_data["logs"][0]
Expand Down Expand Up @@ -875,7 +891,7 @@ def test_handle_read_external_oauth_request(deployment_mode, token_url, expected
source = create_source(config, limits)

with mock.patch.dict(os.environ, {"DEPLOYMENT_MODE": deployment_mode}, clear=False):
output_data = read_stream(source, config, catalog, limits).record.data
output_data = read_stream(source, config, catalog, _A_STATE, limits).record.data
if expected_error:
assert len(output_data["logs"]) > 0, "Expected at least one log message with the expected error"
error_message = output_data["logs"][0]
Expand Down
Loading

0 comments on commit 86ee91e

Please sign in to comment.