diff --git a/airbyte_cdk/sources/streams/concurrent/adapters.py b/airbyte_cdk/sources/streams/concurrent/adapters.py index f304bfb21..7da594155 100644 --- a/airbyte_cdk/sources/streams/concurrent/adapters.py +++ b/airbyte_cdk/sources/streams/concurrent/adapters.py @@ -276,7 +276,7 @@ def __init__( def read(self) -> Iterable[Record]: """ Read messages from the stream. - If the StreamData is a Mapping, it will be converted to a Record. + If the StreamData is a Mapping or an AirbyteMessage of type RECORD, it will be converted to a Record. Otherwise, the message will be emitted on the message repository. """ try: @@ -292,6 +292,8 @@ def read(self) -> Iterable[Record]: stream_slice=copy.deepcopy(self._slice), stream_state=self._state, ): + # Noting we'll also need to support FileTransferRecordMessage if we want to support file-based connectors in this facade + # For now, file-based connectors have their own stream facade if isinstance(record_data, Mapping): data_to_return = dict(record_data) self._stream.transformer.transform( @@ -302,6 +304,12 @@ def read(self) -> Iterable[Record]: stream_name=self.stream_name(), associated_slice=self._slice, # type: ignore [arg-type] ) + elif isinstance(record_data, AirbyteMessage) and record_data.record is not None: + yield Record( + data=record_data.record.data or {}, + stream_name=self.stream_name(), + associated_slice=self._slice, # type: ignore [arg-type] + ) else: self._message_repository.emit_message(record_data) except Exception as e: diff --git a/unit_tests/sources/streams/concurrent/test_adapters.py b/unit_tests/sources/streams/concurrent/test_adapters.py index f809fa38d..66f48a9e0 100644 --- a/unit_tests/sources/streams/concurrent/test_adapters.py +++ b/unit_tests/sources/streams/concurrent/test_adapters.py @@ -7,7 +7,14 @@ import pytest -from airbyte_cdk.models import AirbyteLogMessage, AirbyteMessage, AirbyteStream, Level, SyncMode +from airbyte_cdk.models import ( + AirbyteLogMessage, + AirbyteMessage, + AirbyteRecordMessage, + AirbyteStream, + Level, + SyncMode, +) from airbyte_cdk.models import Type as MessageType from airbyte_cdk.sources.message import InMemoryMessageRepository from airbyte_cdk.sources.streams.concurrent.adapters import ( @@ -132,6 +139,61 @@ def test_stream_partition(transformer, expected_records): assert messages == [a_log_message] +@pytest.mark.parametrize( + "transformer, expected_records", + [ + pytest.param( + TypeTransformer(TransformConfig.NoTransform), + [Record({"data": "1"}, None), Record({"data": "2"}, None)], + id="test_no_transform", + ), + ], +) +def test_stream_partition_read_airbyte_message(transformer, expected_records): + stream = Mock() + stream.name = _STREAM_NAME + stream.get_json_schema.return_value = { + "type": "object", + "properties": {"data": {"type": ["integer"]}}, + } + stream.transformer = transformer + message_repository = InMemoryMessageRepository() + _slice = None + sync_mode = SyncMode.full_refresh + cursor_field = None + state = None + partition = StreamPartition(stream, _slice, message_repository, sync_mode, cursor_field, state) + + a_log_message = AirbyteMessage( + type=MessageType.LOG, + log=AirbyteLogMessage( + level=Level.INFO, + message='slice:{"partition": 1}', + ), + ) + for record in expected_records: + record.partition = partition + + stream_data = [ + a_log_message, + AirbyteMessage( + type=MessageType.RECORD, + record=AirbyteRecordMessage(stream=stream.name, data={"data": "1"}, emitted_at=1), + ), + AirbyteMessage( + type=MessageType.RECORD, + record=AirbyteRecordMessage(stream=stream.name, data={"data": "2"}, emitted_at=2), + ), + ] + stream.read_records.return_value = stream_data + + records = list(partition.read()) + messages = list(message_repository.consume_queue()) + + assert records == expected_records + assert messages == [a_log_message] + + @pytest.mark.parametrize( "exception_type, expected_display_message", [