diff --git a/airbyte-cdk/python/airbyte_cdk/sources/abstract_source.py b/airbyte-cdk/python/airbyte_cdk/sources/abstract_source.py index 0f8bf716cc10dd..d0d56b6f903841 100644 --- a/airbyte-cdk/python/airbyte_cdk/sources/abstract_source.py +++ b/airbyte-cdk/python/airbyte_cdk/sources/abstract_source.py @@ -15,6 +15,7 @@ ConfiguredAirbyteCatalog, ConfiguredAirbyteStream, Status, + StreamDescriptor, SyncMode, ) from airbyte_cdk.models import Type as MessageType @@ -27,6 +28,7 @@ from airbyte_cdk.sources.utils.record_helper import stream_data_to_airbyte_message from airbyte_cdk.sources.utils.schema_helpers import InternalConfig, split_config from airbyte_cdk.sources.utils.slice_logger import DebugSliceLogger, SliceLogger +from airbyte_cdk.utils.airbyte_secrets_utils import filter_secrets from airbyte_cdk.utils.event_timing import create_timer from airbyte_cdk.utils.stream_status_utils import as_airbyte_message as stream_status_as_airbyte_message from airbyte_cdk.utils.traced_exception import AirbyteTracedException @@ -133,11 +135,16 @@ def read( logger.info(f"Marking stream {configured_stream.stream.name} as STOPPED") yield stream_status_as_airbyte_message(configured_stream.stream, AirbyteStreamStatus.COMPLETE) except AirbyteTracedException as e: + logger.exception(f"Encountered an exception while reading stream {configured_stream.stream.name}") + logger.info(f"Marking stream {configured_stream.stream.name} as STOPPED") yield stream_status_as_airbyte_message(configured_stream.stream, AirbyteStreamStatus.INCOMPLETE) - if self.continue_sync_on_stream_failure: - stream_name_to_exception[stream_instance.name] = e - else: - raise e + yield e.as_sanitized_airbyte_message(stream_descriptor=StreamDescriptor(name=configured_stream.stream.name)) + stream_name_to_exception[stream_instance.name] = e + if self.stop_sync_on_stream_failure: + logger.info( + f"Stopping sync on error from stream {configured_stream.stream.name} because {self.name} does not support continuing syncs on 
error." + ) + break except Exception as e: yield from self._emit_queued_messages() logger.exception(f"Encountered an exception while reading stream {configured_stream.stream.name}") @@ -145,15 +152,27 @@ def read( yield stream_status_as_airbyte_message(configured_stream.stream, AirbyteStreamStatus.INCOMPLETE) display_message = stream_instance.get_error_display_message(e) if display_message: - raise AirbyteTracedException.from_exception(e, message=display_message) from e - raise e + traced_exception = AirbyteTracedException.from_exception(e, message=display_message) + else: + traced_exception = AirbyteTracedException.from_exception(e) + yield traced_exception.as_sanitized_airbyte_message( + stream_descriptor=StreamDescriptor(name=configured_stream.stream.name) + ) + stream_name_to_exception[stream_instance.name] = traced_exception + if self.stop_sync_on_stream_failure: + logger.info(f"{self.name} does not support continuing syncs on error from stream {configured_stream.stream.name}") + break finally: timer.finish_event() logger.info(f"Finished syncing {configured_stream.stream.name}") logger.info(timer.report()) - if self.continue_sync_on_stream_failure and len(stream_name_to_exception) > 0: - raise AirbyteTracedException(message=self._generate_failed_streams_error_message(stream_name_to_exception)) + if len(stream_name_to_exception) > 0: + error_message = self._generate_failed_streams_error_message(stream_name_to_exception) + logger.info(error_message) + # We still raise at least one exception when a stream raises an exception because the platform + # currently relies on a non-zero exit code to determine if a sync attempt has failed + raise AirbyteTracedException(message=error_message) logger.info(f"Finished syncing {self.name}") @property @@ -282,17 +301,17 @@ def message_repository(self) -> Union[None, MessageRepository]: return _default_message_repository @property - def continue_sync_on_stream_failure(self) -> bool: + def stop_sync_on_stream_failure(self) -> 
bool: """ WARNING: This function is in-development which means it is subject to change. Use at your own risk. - By default, a source should raise an exception and stop the sync when it encounters an error while syncing a stream. This - method can be overridden on a per-source basis so that a source will continue syncing streams other streams even if an - exception is raised for a stream. + By default, when a source encounters an exception while syncing a stream, it will emit an error trace message and then + continue syncing the next stream. This can be overridden on a per-source basis so that the source will stop the sync + on the first error seen and emit a single error trace message for that stream. """ return False @staticmethod def _generate_failed_streams_error_message(stream_failures: Mapping[str, AirbyteTracedException]) -> str: - failures = ", ".join([f"{stream}: {exception.__repr__()}" for stream, exception in stream_failures.items()]) + failures = ", ".join([f"{stream}: {filter_secrets(exception.__repr__())}" for stream, exception in stream_failures.items()]) return f"During the sync, the following streams did not sync successfully: {failures}" diff --git a/airbyte-cdk/python/airbyte_cdk/utils/traced_exception.py b/airbyte-cdk/python/airbyte_cdk/utils/traced_exception.py index dec09fcf19290f..753296a5dd74df 100644 --- a/airbyte-cdk/python/airbyte_cdk/utils/traced_exception.py +++ b/airbyte-cdk/python/airbyte_cdk/utils/traced_exception.py @@ -13,6 +13,7 @@ AirbyteTraceMessage, FailureType, Status, + StreamDescriptor, TraceType, ) from airbyte_cdk.models import Type as MessageType @@ -43,7 +44,7 @@ def __init__( self._exception = exception super().__init__(internal_message) - def as_airbyte_message(self) -> AirbyteMessage: + def as_airbyte_message(self, stream_descriptor: StreamDescriptor = None) -> AirbyteMessage: """ Builds an AirbyteTraceMessage from the exception """ @@ -60,6 +61,7 @@ def as_airbyte_message(self) -> AirbyteMessage: 
internal_message=self.internal_message, failure_type=self.failure_type, stack_trace=stack_trace_str, + stream_descriptor=stream_descriptor, ), ) @@ -88,3 +90,16 @@ def from_exception(cls, exc: BaseException, *args, **kwargs) -> "AirbyteTracedEx :param exc: the exception that caused the error """ return cls(internal_message=str(exc), exception=exc, *args, **kwargs) # type: ignore # ignoring because of args and kwargs + + def as_sanitized_airbyte_message(self, stream_descriptor: StreamDescriptor = None) -> AirbyteMessage: + """ + Builds an AirbyteTraceMessage from the exception and sanitizes any secrets from the message body + """ + error_message = self.as_airbyte_message(stream_descriptor=stream_descriptor) + if error_message.trace.error.message: + error_message.trace.error.message = filter_secrets(error_message.trace.error.message) + if error_message.trace.error.internal_message: + error_message.trace.error.internal_message = filter_secrets(error_message.trace.error.internal_message) + if error_message.trace.error.stack_trace: + error_message.trace.error.stack_trace = filter_secrets(error_message.trace.error.stack_trace) + return error_message diff --git a/airbyte-cdk/python/unit_tests/sources/test_abstract_source.py b/airbyte-cdk/python/unit_tests/sources/test_abstract_source.py index 4315f488112d36..f38a9e21b555ac 100644 --- a/airbyte-cdk/python/unit_tests/sources/test_abstract_source.py +++ b/airbyte-cdk/python/unit_tests/sources/test_abstract_source.py @@ -13,6 +13,7 @@ from airbyte_cdk.models import ( AirbyteCatalog, AirbyteConnectionStatus, + AirbyteErrorTraceMessage, AirbyteLogMessage, AirbyteMessage, AirbyteRecordMessage, @@ -27,6 +28,7 @@ ConfiguredAirbyteCatalog, ConfiguredAirbyteStream, DestinationSyncMode, + FailureType, Level, Status, StreamDescriptor, @@ -40,6 +42,7 @@ from airbyte_cdk.sources.message import MessageRepository from airbyte_cdk.sources.streams import IncrementalMixin, Stream from airbyte_cdk.sources.utils.record_helper import 
stream_data_to_airbyte_message +from airbyte_cdk.utils.airbyte_secrets_utils import update_secrets from airbyte_cdk.utils.traced_exception import AirbyteTracedException from pytest import fixture @@ -54,12 +57,14 @@ def __init__( per_stream: bool = True, message_repository: MessageRepository = None, exception_on_missing_stream: bool = True, + stop_sync_on_stream_failure: bool = False, ): self._streams = streams self.check_lambda = check_lambda self.per_stream = per_stream self.exception_on_missing_stream = exception_on_missing_stream self._message_repository = message_repository + self._stop_sync_on_stream_failure = stop_sync_on_stream_failure def check_connection(self, logger: logging.Logger, config: Mapping[str, Any]) -> Tuple[bool, Optional[Any]]: if self.check_lambda: @@ -84,6 +89,12 @@ def message_repository(self): return self._message_repository +class MockSourceWithStopSyncFalseOverride(MockSource): + @property + def stop_sync_on_stream_failure(self) -> bool: + return False + + class StreamNoStateMethod(Stream): name = "managers" primary_key = None @@ -115,8 +126,11 @@ class StreamRaisesException(Stream): name = "lamentations" primary_key = None + def __init__(self, exception_to_raise): + self._exception_to_raise = exception_to_raise + def read_records(self, *args, **kwargs) -> Iterable[Mapping[str, Any]]: - raise AirbyteTracedException(message="I was born only to crash like Icarus") + raise self._exception_to_raise MESSAGE_FROM_REPOSITORY = Mock() @@ -291,7 +305,7 @@ def test_read_stream_emits_repository_message_on_error(mocker, message_repositor source = MockSource(streams=[stream], message_repository=message_repository) - with pytest.raises(RuntimeError): + with pytest.raises(AirbyteTracedException): messages = list(source.read(logger, {}, ConfiguredAirbyteCatalog(streams=[_configured_stream(stream, SyncMode.full_refresh)]))) assert MESSAGE_FROM_REPOSITORY in messages @@ -306,14 +320,14 @@ def test_read_stream_with_error_gets_display_message(mocker): 
catalog = ConfiguredAirbyteCatalog(streams=[_configured_stream(stream, SyncMode.full_refresh)]) # without get_error_display_message - with pytest.raises(RuntimeError, match="oh no!"): + with pytest.raises(AirbyteTracedException): list(source.read(logger, {}, catalog)) mocker.patch.object(MockStream, "get_error_display_message", return_value="my message") - with pytest.raises(AirbyteTracedException, match="oh no!") as exc: + with pytest.raises(AirbyteTracedException) as exc: list(source.read(logger, {}, catalog)) - assert exc.value.message == "my message" + assert "oh no!" in exc.value.message GLOBAL_EMITTED_AT = 1 @@ -358,6 +372,22 @@ def _as_state(state_data: Dict[str, Any], stream_name: str = "", per_stream_stat return AirbyteMessage(type=Type.STATE, state=AirbyteStateMessage(data=state_data)) +def _as_error_trace(stream: str, error_message: str, internal_message: Optional[str], failure_type: Optional[FailureType], stack_trace: Optional[str]) -> AirbyteMessage: + trace_message = AirbyteTraceMessage( + emitted_at=datetime.datetime.now().timestamp() * 1000.0, + type=TraceType.ERROR, + error=AirbyteErrorTraceMessage( + stream_descriptor=StreamDescriptor(name=stream), + message=error_message, + internal_message=internal_message, + failure_type=failure_type, + stack_trace=stack_trace, + ), + ) + + return AirbyteMessage(type=MessageType.TRACE, trace=trace_message) + + def _configured_stream(stream: Stream, sync_mode: SyncMode): return ConfiguredAirbyteStream( stream=stream.as_airbyte_stream(), @@ -1168,21 +1198,27 @@ def test_checkpoint_state_from_stream_instance(): ) -def test_continue_sync_with_failed_streams(mocker): +@pytest.mark.parametrize( + "exception_to_raise,expected_error_message,expected_internal_message", + [ + pytest.param(AirbyteTracedException(message="I was born only to crash like Icarus"), "I was born only to crash like Icarus", None, id="test_raises_traced_exception"), + pytest.param(Exception("Generic connector error message"), "Something went wrong 
in the connector. See the logs for more details.", "Generic connector error message", id="test_raises_generic_exception"), + ] +) +def test_continue_sync_with_failed_streams(mocker, exception_to_raise, expected_error_message, expected_internal_message): """ - Tests that running a sync for a connector with multiple streams and continue_sync_on_stream_failure enabled continues - syncing even when one stream fails with an error. + Tests that running a sync for a connector with multiple streams will continue syncing when one stream fails + with an error. This source does not override the default behavior defined in the AbstractSource class. """ stream_output = [{"k1": "v1"}, {"k2": "v2"}] s1 = MockStream([({"sync_mode": SyncMode.full_refresh}, stream_output)], name="s1") - s2 = StreamRaisesException() + s2 = StreamRaisesException(exception_to_raise=exception_to_raise) s3 = MockStream([({"sync_mode": SyncMode.full_refresh}, stream_output)], name="s3") mocker.patch.object(MockStream, "get_json_schema", return_value={}) mocker.patch.object(StreamRaisesException, "get_json_schema", return_value={}) src = MockSource(streams=[s1, s2, s3]) - mocker.patch.object(MockSource, "continue_sync_on_stream_failure", return_value=True) catalog = ConfiguredAirbyteCatalog( streams=[ _configured_stream(s1, SyncMode.full_refresh), @@ -1199,6 +1235,7 @@ def test_continue_sync_with_failed_streams(mocker): _as_stream_status("s1", AirbyteStreamStatus.COMPLETE), _as_stream_status("lamentations", AirbyteStreamStatus.STARTED), _as_stream_status("lamentations", AirbyteStreamStatus.INCOMPLETE), + _as_error_trace("lamentations", expected_error_message, expected_internal_message, FailureType.system_error, None), _as_stream_status("s3", AirbyteStreamStatus.STARTED), _as_stream_status("s3", AirbyteStreamStatus.RUNNING), *_as_records("s3", stream_output), @@ -1206,26 +1243,73 @@ def test_continue_sync_with_failed_streams(mocker): ] ) - messages = [] with pytest.raises(AirbyteTracedException) as exc: - # 
We can't use list comprehension or list() here because we are still raising a final exception for the - # failed streams and that disrupts parsing the generator into the messages emitted before - for message in src.read(logger, {}, catalog): - messages.append(message) + messages = [_remove_stack_trace(message) for message in src.read(logger, {}, catalog)] + messages = _fix_emitted_at(messages) + + assert expected == messages - messages = _fix_emitted_at(messages) - assert expected == messages assert "lamentations" in exc.value.message -def test_stop_sync_with_failed_streams(mocker): +def test_continue_sync_source_override_false(mocker): """ - Tests that running a sync for a connector with multiple streams and continue_sync_on_stream_failure disabled stops - syncing once a stream fails with an error. + Tests that running a sync for a connector explicitly overriding the default AbstractSource.stop_sync_on_stream_failure + property to be False which will continue syncing stream even if one encountered an exception. 
""" + update_secrets(["API_KEY_VALUE"]) + stream_output = [{"k1": "v1"}, {"k2": "v2"}] s1 = MockStream([({"sync_mode": SyncMode.full_refresh}, stream_output)], name="s1") - s2 = StreamRaisesException() + s2 = StreamRaisesException(exception_to_raise=AirbyteTracedException(message="I was born only to crash like Icarus")) + s3 = MockStream([({"sync_mode": SyncMode.full_refresh}, stream_output)], name="s3") + + mocker.patch.object(MockStream, "get_json_schema", return_value={}) + mocker.patch.object(StreamRaisesException, "get_json_schema", return_value={}) + + src = MockSourceWithStopSyncFalseOverride(streams=[s1, s2, s3]) + catalog = ConfiguredAirbyteCatalog( + streams=[ + _configured_stream(s1, SyncMode.full_refresh), + _configured_stream(s2, SyncMode.full_refresh), + _configured_stream(s3, SyncMode.full_refresh), + ] + ) + + expected = _fix_emitted_at( + [ + _as_stream_status("s1", AirbyteStreamStatus.STARTED), + _as_stream_status("s1", AirbyteStreamStatus.RUNNING), + *_as_records("s1", stream_output), + _as_stream_status("s1", AirbyteStreamStatus.COMPLETE), + _as_stream_status("lamentations", AirbyteStreamStatus.STARTED), + _as_stream_status("lamentations", AirbyteStreamStatus.INCOMPLETE), + _as_error_trace("lamentations", "I was born only to crash like Icarus", None, FailureType.system_error, None), + _as_stream_status("s3", AirbyteStreamStatus.STARTED), + _as_stream_status("s3", AirbyteStreamStatus.RUNNING), + *_as_records("s3", stream_output), + _as_stream_status("s3", AirbyteStreamStatus.COMPLETE), + ] + ) + + with pytest.raises(AirbyteTracedException) as exc: + messages = [_remove_stack_trace(message) for message in src.read(logger, {}, catalog)] + messages = _fix_emitted_at(messages) + + assert expected == messages + + assert "lamentations" in exc.value.message + + +def test_sync_error_trace_messages_obfuscate_secrets(mocker): + """ + Tests that exceptions emitted as trace messages by a source have secrets properly sanitized + """ + 
update_secrets(["API_KEY_VALUE"]) + + stream_output = [{"k1": "v1"}, {"k2": "v2"}] + s1 = MockStream([({"sync_mode": SyncMode.full_refresh}, stream_output)], name="s1") + s2 = StreamRaisesException(exception_to_raise=AirbyteTracedException(message="My api_key value API_KEY_VALUE flew too close to the sun.")) s3 = MockStream([({"sync_mode": SyncMode.full_refresh}, stream_output)], name="s3") mocker.patch.object(MockStream, "get_json_schema", return_value={}) @@ -1248,15 +1332,71 @@ def test_stop_sync_with_failed_streams(mocker): _as_stream_status("s1", AirbyteStreamStatus.COMPLETE), _as_stream_status("lamentations", AirbyteStreamStatus.STARTED), _as_stream_status("lamentations", AirbyteStreamStatus.INCOMPLETE), + _as_error_trace("lamentations", "My api_key value **** flew too close to the sun.", None, FailureType.system_error, None), + _as_stream_status("s3", AirbyteStreamStatus.STARTED), + _as_stream_status("s3", AirbyteStreamStatus.RUNNING), + *_as_records("s3", stream_output), + _as_stream_status("s3", AirbyteStreamStatus.COMPLETE), ] ) - messages = [] - with pytest.raises(AirbyteTracedException): - # We can't use list comprehension or list() here because we are still raising a final exception for the - # failed streams and that disrupts parsing the generator into the messages emitted before - for message in src.read(logger, {}, catalog): - messages.append(message) + with pytest.raises(AirbyteTracedException) as exc: + messages = [_remove_stack_trace(message) for message in src.read(logger, {}, catalog)] + messages = _fix_emitted_at(messages) - messages = _fix_emitted_at(messages) - assert expected == messages + assert expected == messages + + assert "lamentations" in exc.value.message + + +def test_continue_sync_with_failed_streams_with_override_false(mocker): + """ + Tests that running a sync for a connector with multiple streams and stop_sync_on_stream_failure enabled stops + the sync when one stream fails with an error. 
+ """ + stream_output = [{"k1": "v1"}, {"k2": "v2"}] + s1 = MockStream([({"sync_mode": SyncMode.full_refresh}, stream_output)], name="s1") + s2 = StreamRaisesException(AirbyteTracedException(message="I was born only to crash like Icarus")) + s3 = MockStream([({"sync_mode": SyncMode.full_refresh}, stream_output)], name="s3") + + mocker.patch.object(MockStream, "get_json_schema", return_value={}) + mocker.patch.object(StreamRaisesException, "get_json_schema", return_value={}) + + src = MockSource(streams=[s1, s2, s3]) + mocker.patch.object(MockSource, "stop_sync_on_stream_failure", return_value=True) + catalog = ConfiguredAirbyteCatalog( + streams=[ + _configured_stream(s1, SyncMode.full_refresh), + _configured_stream(s2, SyncMode.full_refresh), + _configured_stream(s3, SyncMode.full_refresh), + ] + ) + + expected = _fix_emitted_at( + [ + _as_stream_status("s1", AirbyteStreamStatus.STARTED), + _as_stream_status("s1", AirbyteStreamStatus.RUNNING), + *_as_records("s1", stream_output), + _as_stream_status("s1", AirbyteStreamStatus.COMPLETE), + _as_stream_status("lamentations", AirbyteStreamStatus.STARTED), + _as_stream_status("lamentations", AirbyteStreamStatus.INCOMPLETE), + _as_error_trace("lamentations", "I was born only to crash like Icarus", None, FailureType.system_error, None), + ] + ) + + with pytest.raises(AirbyteTracedException) as exc: + messages = [_remove_stack_trace(message) for message in src.read(logger, {}, catalog)] + messages = _fix_emitted_at(messages) + + assert expected == messages + + assert "lamentations" in exc.value.message + + +def _remove_stack_trace(message: AirbyteMessage) -> AirbyteMessage: + """ + Helper method that removes the stack trace from Airbyte trace messages to make asserting against expected records easier + """ + if message.trace and message.trace.error and message.trace.error.stack_trace: + message.trace.error.stack_trace = None + return message diff --git a/airbyte-cdk/python/unit_tests/sources/test_integration_source.py 
b/airbyte-cdk/python/unit_tests/sources/test_integration_source.py index 048864f12c908f..17628a0263cde3 100644 --- a/airbyte-cdk/python/unit_tests/sources/test_integration_source.py +++ b/airbyte-cdk/python/unit_tests/sources/test_integration_source.py @@ -2,7 +2,9 @@ # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +import json import os +from typing import Any, List, Mapping from unittest import mock from unittest.mock import patch @@ -22,9 +24,9 @@ "deployment_mode, url_base, expected_records, expected_error", [ pytest.param("CLOUD", "https://airbyte.com/api/v1/", [], None, id="test_cloud_read_with_public_endpoint"), - pytest.param("CLOUD", "http://unsecured.com/api/v1/", [], ValueError, id="test_cloud_read_with_unsecured_url"), - pytest.param("CLOUD", "https://172.20.105.99/api/v1/", [], AirbyteTracedException, id="test_cloud_read_with_private_endpoint"), - pytest.param("CLOUD", "https://localhost:80/api/v1/", [], AirbyteTracedException, id="test_cloud_read_with_localhost"), + pytest.param("CLOUD", "http://unsecured.com/api/v1/", [], "system_error", id="test_cloud_read_with_unsecured_url"), + pytest.param("CLOUD", "https://172.20.105.99/api/v1/", [], "config_error", id="test_cloud_read_with_private_endpoint"), + pytest.param("CLOUD", "https://localhost:80/api/v1/", [], "config_error", id="test_cloud_read_with_localhost"), pytest.param("OSS", "https://airbyte.com/api/v1/", [], None, id="test_oss_read_with_public_endpoint"), pytest.param("OSS", "https://172.20.105.99/api/v1/", [], None, id="test_oss_read_with_private_endpoint"), ], @@ -37,8 +39,10 @@ def test_external_request_source(capsys, deployment_mode, url_base, expected_rec with mock.patch.object(HttpTestStream, "url_base", url_base): args = ["read", "--config", "config.json", "--catalog", "configured_catalog.json"] if expected_error: - with pytest.raises(expected_error): + with pytest.raises(AirbyteTracedException): launch(source, args) + messages = [json.loads(line) for line in 
capsys.readouterr().out.splitlines()] + assert contains_error_trace_message(messages, expected_error) else: launch(source, args) @@ -47,14 +51,14 @@ def test_external_request_source(capsys, deployment_mode, url_base, expected_rec "deployment_mode, token_refresh_url, expected_records, expected_error", [ pytest.param("CLOUD", "https://airbyte.com/api/v1/", [], None, id="test_cloud_read_with_public_endpoint"), - pytest.param("CLOUD", "http://unsecured.com/api/v1/", [], ValueError, id="test_cloud_read_with_unsecured_url"), - pytest.param("CLOUD", "https://172.20.105.99/api/v1/", [], AirbyteTracedException, id="test_cloud_read_with_private_endpoint"), + pytest.param("CLOUD", "http://unsecured.com/api/v1/", [], "system_error", id="test_cloud_read_with_unsecured_url"), + pytest.param("CLOUD", "https://172.20.105.99/api/v1/", [], "config_error", id="test_cloud_read_with_private_endpoint"), pytest.param("OSS", "https://airbyte.com/api/v1/", [], None, id="test_oss_read_with_public_endpoint"), pytest.param("OSS", "https://172.20.105.99/api/v1/", [], None, id="test_oss_read_with_private_endpoint"), ], ) @patch.object(requests.Session, "send", fixture_mock_send) -def test_external_oauth_request_source(deployment_mode, token_refresh_url, expected_records, expected_error): +def test_external_oauth_request_source(capsys, deployment_mode, token_refresh_url, expected_records, expected_error): oauth_authenticator = SourceFixtureOauthAuthenticator( client_id="nora", client_secret="hae_sung", refresh_token="arthur", token_refresh_endpoint=token_refresh_url ) @@ -63,7 +67,20 @@ def test_external_oauth_request_source(deployment_mode, token_refresh_url, expec with mock.patch.dict(os.environ, {"DEPLOYMENT_MODE": deployment_mode}, clear=False): # clear=True clears the existing os.environ dict args = ["read", "--config", "config.json", "--catalog", "configured_catalog.json"] if expected_error: - with pytest.raises(expected_error): + with pytest.raises(AirbyteTracedException): launch(source, 
args) + messages = [json.loads(line) for line in capsys.readouterr().out.splitlines()] + assert contains_error_trace_message(messages, expected_error) else: launch(source, args) + + +def contains_error_trace_message(messages: List[Mapping[str, Any]], expected_error: str) -> bool: + for message in messages: + if message.get("type") != "TRACE": + continue + elif message.get("trace").get("type") != "ERROR": + continue + elif message.get("trace").get("error").get("failure_type") == expected_error: + return True + return False diff --git a/airbyte-cdk/python/unit_tests/sources/test_source_read.py b/airbyte-cdk/python/unit_tests/sources/test_source_read.py index 752c4640d3edd0..dd08c4d18dacf7 100644 --- a/airbyte-cdk/python/unit_tests/sources/test_source_read.py +++ b/airbyte-cdk/python/unit_tests/sources/test_source_read.py @@ -343,7 +343,7 @@ def test_concurrent_source_yields_the_same_messages_as_abstract_source_when_an_e source, concurrent_source = _init_sources([stream_slice_to_partition], state, logger) config = {} catalog = _create_configured_catalog(source._streams) - messages_from_abstract_source = _read_from_source(source, logger, config, catalog, state, RuntimeError) + messages_from_abstract_source = _read_from_source(source, logger, config, catalog, state, AirbyteTracedException) messages_from_concurrent_source = _read_from_source(concurrent_source, logger, config, catalog, state, RuntimeError) expected_messages = [