Skip to content

Commit

Permalink
feat(airbyte-cdk): add json_schema from ConfiguredCatalog to `Stream` (#39522)
Browse files Browse the repository at this point in the history

Signed-off-by: Artem Inzhyyants <artem.inzhyyants@gmail.com>
  • Loading branch information
artem1205 committed Jun 19, 2024
1 parent 363c4b1 commit f49c805
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 22 deletions.
18 changes: 17 additions & 1 deletion airbyte-cdk/python/airbyte_cdk/sources/streams/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import typing
from abc import ABC, abstractmethod
from functools import lru_cache
from typing import Any, Iterable, List, Mapping, MutableMapping, Optional, Tuple, Union
from typing import Any, Dict, Iterable, List, Mapping, MutableMapping, Optional, Tuple, Union

import airbyte_cdk.sources.utils.casing as casing
from airbyte_cdk.models import AirbyteMessage, AirbyteStream, ConfiguredAirbyteStream, SyncMode
Expand Down Expand Up @@ -105,6 +105,8 @@ class Stream(ABC):
Base abstract class for an Airbyte Stream. Makes no assumption of the Stream's underlying transport protocol.
"""

_configured_json_schema: Optional[Dict[str, Any]] = None

# Use self.logger in subclasses to log any messages
@property
def logger(self) -> logging.Logger:
Expand Down Expand Up @@ -143,6 +145,7 @@ def read( # type: ignore # ignoring typing for ConnectorStateManager because o
) -> Iterable[StreamData]:
sync_mode = configured_stream.sync_mode
cursor_field = configured_stream.cursor_field
self.configured_json_schema = configured_stream.stream.json_schema

# WARNING: When performing a read() that uses incoming stream state, we MUST use the self.state that is defined as
# opposed to the incoming stream_state value. Because some connectors like ones using the file-based CDK modify
Expand Down Expand Up @@ -502,3 +505,16 @@ def _checkpoint_state( # type: ignore # ignoring typing for ConnectorStateMana
# to reduce changes right now and this would span concurrent as well
state_manager.update_state_for_stream(self.name, self.namespace, stream_state)
return state_manager.create_state_message(self.name, self.namespace)

@property
def configured_json_schema(self) -> Optional[Dict[str, Any]]:
    """
    JSON schema taken from the configured catalog for this stream.

    The backing attribute ``_configured_json_schema`` is populated through the
    setter of this property (the ``read`` method assigns it from the configured
    stream it receives).

    :return Optional[Dict]: the stored JSON schema, or None if none has been assigned.
    """
    stored_schema = self._configured_json_schema
    return stored_schema

@configured_json_schema.setter
def configured_json_schema(self, json_schema: Dict[str, Any]) -> None:
    """
    Remember the JSON schema so it can be read back through the property.

    :param json_schema: schema mapping to store; it is kept by reference, not copied.
    """
    self._configured_json_schema = json_schema
88 changes: 67 additions & 21 deletions airbyte-cdk/python/unit_tests/sources/streams/test_stream_read.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,9 +122,9 @@ def close_partition(self, partition: Partition) -> None:
state=AirbyteStateMessage(
type=AirbyteStateType.STREAM,
stream=AirbyteStreamState(
stream_descriptor=StreamDescriptor(name='__mock_stream', namespace=None),
stream_descriptor=StreamDescriptor(name="__mock_stream", namespace=None),
stream_state=AirbyteStateBlob(**self._state),
)
),
),
)
)
Expand Down Expand Up @@ -177,7 +177,11 @@ def get_updated_state(current_stream_state: MutableMapping[str, Any], latest_rec
def test_full_refresh_read_a_single_slice_with_debug(constructor):
# This test verifies that a concurrent stream adapted from a Stream behaves the same as the Stream object.
# It is done by running the same test cases on both streams
configured_stream = ConfiguredAirbyteStream(stream=AirbyteStream(name="mock_stream", supported_sync_modes=[SyncMode.full_refresh], json_schema={}), sync_mode=SyncMode.full_refresh,destination_sync_mode=DestinationSyncMode.overwrite)
configured_stream = ConfiguredAirbyteStream(
stream=AirbyteStream(name="mock_stream", supported_sync_modes=[SyncMode.full_refresh], json_schema={}),
sync_mode=SyncMode.full_refresh,
destination_sync_mode=DestinationSyncMode.overwrite,
)
internal_config = InternalConfig()
records = [
{"id": 1, "partition": 1},
Expand Down Expand Up @@ -211,9 +215,9 @@ def test_full_refresh_read_a_single_slice_with_debug(constructor):
state=AirbyteStateMessage(
type=AirbyteStateType.STREAM,
stream=AirbyteStreamState(
stream_descriptor=StreamDescriptor(name='__mock_stream', namespace=None),
stream_descriptor=StreamDescriptor(name="__mock_stream", namespace=None),
stream_state=AirbyteStateBlob(__ab_no_cursor_state_message=True),
)
),
),
),
)
Expand All @@ -237,7 +241,11 @@ def test_full_refresh_read_a_single_slice_with_debug(constructor):
def test_full_refresh_read_a_single_slice(constructor):
# This test verifies that a concurrent stream adapted from a Stream behaves the same as the Stream object.
# It is done by running the same test cases on both streams
configured_stream = ConfiguredAirbyteStream(stream=AirbyteStream(name="mock_stream", supported_sync_modes=[SyncMode.full_refresh], json_schema={}), sync_mode=SyncMode.full_refresh,destination_sync_mode=DestinationSyncMode.overwrite)
configured_stream = ConfiguredAirbyteStream(
stream=AirbyteStream(name="mock_stream", supported_sync_modes=[SyncMode.full_refresh], json_schema={}),
sync_mode=SyncMode.full_refresh,
destination_sync_mode=DestinationSyncMode.overwrite,
)
internal_config = InternalConfig()
logger = _mock_logger()
slice_logger = DebugSliceLogger()
Expand All @@ -263,9 +271,9 @@ def test_full_refresh_read_a_single_slice(constructor):
state=AirbyteStateMessage(
type=AirbyteStateType.STREAM,
stream=AirbyteStreamState(
stream_descriptor=StreamDescriptor(name='__mock_stream', namespace=None),
stream_descriptor=StreamDescriptor(name="__mock_stream", namespace=None),
stream_state=AirbyteStateBlob(__ab_no_cursor_state_message=True),
)
),
),
),
)
Expand All @@ -290,7 +298,11 @@ def test_full_refresh_read_a_single_slice(constructor):
def test_full_refresh_read_two_slices(constructor):
# This test verifies that a concurrent stream adapted from a Stream behaves the same as the Stream object
# It is done by running the same test cases on both streams
configured_stream = ConfiguredAirbyteStream(stream=AirbyteStream(name="mock_stream", supported_sync_modes=[SyncMode.full_refresh], json_schema={}), sync_mode=SyncMode.full_refresh,destination_sync_mode=DestinationSyncMode.overwrite)
configured_stream = ConfiguredAirbyteStream(
stream=AirbyteStream(name="mock_stream", supported_sync_modes=[SyncMode.full_refresh], json_schema={}),
sync_mode=SyncMode.full_refresh,
destination_sync_mode=DestinationSyncMode.overwrite,
)
internal_config = InternalConfig()
logger = _mock_logger()
slice_logger = DebugSliceLogger()
Expand Down Expand Up @@ -323,9 +335,9 @@ def test_full_refresh_read_two_slices(constructor):
state=AirbyteStateMessage(
type=AirbyteStateType.STREAM,
stream=AirbyteStreamState(
stream_descriptor=StreamDescriptor(name='__mock_stream', namespace=None),
stream_descriptor=StreamDescriptor(name="__mock_stream", namespace=None),
stream_state=AirbyteStateBlob(__ab_no_cursor_state_message=True),
)
),
),
),
)
Expand All @@ -344,14 +356,10 @@ def test_full_refresh_read_two_slices(constructor):
def test_incremental_read_two_slices():
# This test verifies that a stream running in incremental mode emits state messages correctly
configured_stream = ConfiguredAirbyteStream(
stream=AirbyteStream(
name="mock_stream",
supported_sync_modes=[SyncMode.full_refresh, SyncMode.incremental],
json_schema={}
),
stream=AirbyteStream(name="mock_stream", supported_sync_modes=[SyncMode.full_refresh, SyncMode.incremental], json_schema={}),
sync_mode=SyncMode.incremental,
cursor_field=["created_at"],
destination_sync_mode=DestinationSyncMode.overwrite
destination_sync_mode=DestinationSyncMode.overwrite,
)
internal_config = InternalConfig()
logger = _mock_logger()
Expand All @@ -375,7 +383,7 @@ def test_incremental_read_two_slices():
*records_partition_1,
_create_state_message("__mock_incremental_stream", {"created_at": timestamp}),
*records_partition_2,
_create_state_message("__mock_incremental_stream", {"created_at": timestamp})
_create_state_message("__mock_incremental_stream", {"created_at": timestamp}),
]

actual_records = _read(stream, configured_stream, logger, slice_logger, message_repository, state_manager, internal_config)
Expand All @@ -387,7 +395,11 @@ def test_incremental_read_two_slices():

def test_concurrent_incremental_read_two_slices():
# This test verifies that an incremental concurrent stream manages state correctly for multiple slices syncing concurrently
configured_stream = ConfiguredAirbyteStream(stream=AirbyteStream(name="mock_stream", supported_sync_modes=[SyncMode.full_refresh, SyncMode.incremental], json_schema={}), sync_mode=SyncMode.incremental,destination_sync_mode=DestinationSyncMode.overwrite)
configured_stream = ConfiguredAirbyteStream(
stream=AirbyteStream(name="mock_stream", supported_sync_modes=[SyncMode.full_refresh, SyncMode.incremental], json_schema={}),
sync_mode=SyncMode.incremental,
destination_sync_mode=DestinationSyncMode.overwrite,
)
internal_config = InternalConfig()
logger = _mock_logger()
slice_logger = DebugSliceLogger()
Expand All @@ -413,7 +425,9 @@ def test_concurrent_incremental_read_two_slices():
*records_partition_2,
]

expected_state = _create_state_message("__mock_stream", {"1": {"created_at": slice_timestamp_1}, "2": {"created_at": slice_timestamp_2}})
expected_state = _create_state_message(
"__mock_stream", {"1": {"created_at": slice_timestamp_1}, "2": {"created_at": slice_timestamp_2}}
)

actual_records = _read(stream, configured_stream, logger, slice_logger, message_repository, state_manager, internal_config)

Expand All @@ -430,6 +444,38 @@ def test_concurrent_incremental_read_two_slices():
assert actual_state[0] == expected_state


def test_configured_json_schema():
    # Verifies that reading a stream makes the JSON schema of the configured catalog
    # available through `stream.configured_json_schema`: the property is falsy before
    # the read and equal to the catalog schema afterwards.
    expected_schema = {
        "$schema": "https://json-schema.org/draft-07/schema#",
        "type": "object",
        "properties": {
            "id": {"type": ["null", "number"]},
            "name": {"type": ["null", "string"]},
        },
    }
    airbyte_stream = AirbyteStream(
        name="mock_stream",
        supported_sync_modes=[SyncMode.full_refresh],
        json_schema=expected_schema,
    )
    configured_stream = ConfiguredAirbyteStream(
        stream=airbyte_stream,
        sync_mode=SyncMode.full_refresh,
        destination_sync_mode=DestinationSyncMode.overwrite,
    )
    internal_config = InternalConfig()
    logger = _mock_logger()
    slice_logger = DebugSliceLogger()
    message_repository = InMemoryMessageRepository(Level.INFO)
    state_manager = ConnectorStateManager(stream_instance_map={})

    # A single slice (key 1) holding two records is enough to drive one read.
    single_slice_records = [{"id": record_id, "partition": 1} for record_id in (1, 2)]
    stream = _stream({1: single_slice_records}, slice_logger, logger, message_repository)

    # Nothing has been read yet, so no schema has been handed to the stream.
    assert not stream.configured_json_schema

    _read(stream, configured_stream, logger, slice_logger, message_repository, state_manager, internal_config)

    assert stream.configured_json_schema == expected_schema


def _read(stream, configured_stream, logger, slice_logger, message_repository, state_manager, internal_config):
records = []
for record in stream.read(configured_stream, logger, slice_logger, {}, state_manager, internal_config):
Expand Down Expand Up @@ -468,6 +514,6 @@ def _create_state_message(stream: str, state: Mapping[str, Any]) -> AirbyteMessa
stream=AirbyteStreamState(
stream_descriptor=StreamDescriptor(name=stream, namespace=None),
stream_state=AirbyteStateBlob(**state),
)
),
),
)

0 comments on commit f49c805

Please sign in to comment.