Skip to content

Commit

Permalink
Add discover to entrypoint wrapper (#39396)
Browse files Browse the repository at this point in the history
  • Loading branch information
maxi297 committed Jun 11, 2024
1 parent 721ca46 commit 9b1d720
Show file tree
Hide file tree
Showing 2 changed files with 158 additions and 45 deletions.
87 changes: 59 additions & 28 deletions airbyte-cdk/python/airbyte_cdk/test/entrypoint_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,13 @@ def analytics_messages(self) -> List[AirbyteMessage]:
def errors(self) -> List[AirbyteMessage]:
return self._get_trace_message_by_trace_type(TraceType.ERROR)

@property
def catalog(self) -> AirbyteMessage:
catalog = self._get_message_by_types([Type.CATALOG])
if len(catalog) != 1:
raise ValueError(f"Expected exactly one catalog but got {len(catalog)}")
return catalog[0]

def get_stream_statuses(self, stream_name: str) -> List[AirbyteStreamStatus]:
status_messages = map(
lambda message: message.trace.stream_status.status,
Expand All @@ -109,6 +116,53 @@ def _get_trace_message_by_trace_type(self, trace_type: TraceType) -> List[Airbyt
return [message for message in self._get_message_by_types([Type.TRACE]) if message.trace.type == trace_type]


def _run_command(source: Source, args: List[str], expecting_exception: bool = False) -> EntrypointOutput:
log_capture_buffer = StringIO()
stream_handler = logging.StreamHandler(log_capture_buffer)
stream_handler.setLevel(logging.INFO)
stream_handler.setFormatter(AirbyteLogFormatter())
parent_logger = logging.getLogger("")
parent_logger.addHandler(stream_handler)

parsed_args = AirbyteEntrypoint.parse_args(args)

source_entrypoint = AirbyteEntrypoint(source)
messages = []
uncaught_exception = None
try:
for message in source_entrypoint.run(parsed_args):
messages.append(message)
except Exception as exception:
if not expecting_exception:
print("Printing unexpected error from entrypoint_wrapper")
print("".join(traceback.format_exception(None, exception, exception.__traceback__)))
uncaught_exception = exception

captured_logs = log_capture_buffer.getvalue().split("\n")[:-1]

parent_logger.removeHandler(stream_handler)

return EntrypointOutput(messages + captured_logs, uncaught_exception)


def discover(
source: Source,
config: Mapping[str, Any],
expecting_exception: bool = False,
) -> EntrypointOutput:
"""
config must be json serializable
:param expecting_exception: By default if there is an uncaught exception, the exception will be printed out. If this is expected, please
provide expecting_exception=True so that the test output logs are cleaner
"""

with tempfile.TemporaryDirectory() as tmp_directory:
tmp_directory_path = Path(tmp_directory)
config_file = make_file(tmp_directory_path / "config.json", config)

return _run_command(source, ["discover", "--config", config_file, "--debug"], expecting_exception)


def read(
source: Source,
config: Mapping[str, Any],
Expand All @@ -122,21 +176,16 @@ def read(
:param expecting_exception: By default if there is an uncaught exception, the exception will be printed out. If this is expected, please
provide expecting_exception=True so that the test output logs are cleaner
"""
log_capture_buffer = StringIO()
stream_handler = logging.StreamHandler(log_capture_buffer)
stream_handler.setLevel(logging.INFO)
stream_handler.setFormatter(AirbyteLogFormatter())
parent_logger = logging.getLogger("")
parent_logger.addHandler(stream_handler)

with tempfile.TemporaryDirectory() as tmp_directory:
tmp_directory_path = Path(tmp_directory)
config_file = make_file(tmp_directory_path / "config.json", config)
catalog_file = make_file(tmp_directory_path / "catalog.json", catalog.json())
args = [
"read",
"--config",
make_file(tmp_directory_path / "config.json", config),
config_file,
"--catalog",
make_file(tmp_directory_path / "catalog.json", catalog.json()),
catalog_file,
]
if state is not None:
args.extend(
Expand All @@ -145,26 +194,8 @@ def read(
make_file(tmp_directory_path / "state.json", f"[{','.join([stream_state.json() for stream_state in state])}]"),
]
)
args.append("--debug")
source_entrypoint = AirbyteEntrypoint(source)
parsed_args = source_entrypoint.parse_args(args)

messages = []
uncaught_exception = None
try:
for message in source_entrypoint.run(parsed_args):
messages.append(message)
except Exception as exception:
if not expecting_exception:
print("Printing unexpected error from entrypoint_wrapper")
print("".join(traceback.format_exception(None, exception, exception.__traceback__)))
uncaught_exception = exception

captured_logs = log_capture_buffer.getvalue().split("\n")[:-1]

parent_logger.removeHandler(stream_handler)

return EntrypointOutput(messages + captured_logs, uncaught_exception)
return _run_command(source, args, expecting_exception)


def make_file(path: Path, file_contents: Optional[Union[str, Mapping[str, Any], List[Mapping[str, Any]]]]) -> str:
Expand Down
116 changes: 99 additions & 17 deletions airbyte-cdk/python/unit_tests/test/test_entrypoint_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,16 @@
import json
import logging
import os
from typing import Any, Iterator, List, Mapping
from typing import Any, Iterator, List, Mapping, Optional
from unittest import TestCase
from unittest.mock import Mock, patch

from airbyte_cdk.sources.abstract_source import AbstractSource
from airbyte_cdk.test.entrypoint_wrapper import read
from airbyte_cdk.test.entrypoint_wrapper import discover, read
from airbyte_cdk.test.state_builder import StateBuilder
from airbyte_protocol.models import (
AirbyteAnalyticsTraceMessage,
AirbyteCatalog,
AirbyteErrorTraceMessage,
AirbyteLogMessage,
AirbyteMessage,
Expand Down Expand Up @@ -48,6 +49,10 @@ def _a_status_message(stream_name: str, status: AirbyteStreamStatus) -> AirbyteM
)


_A_CATALOG_MESSAGE = AirbyteMessage(
type=Type.CATALOG,
catalog=AirbyteCatalog(streams=[]),
)
_A_RECORD = AirbyteMessage(
type=Type.RECORD, record=AirbyteRecordMessage(stream="stream", data={"record key": "record value"}, emitted_at=0)
)
Expand Down Expand Up @@ -110,17 +115,93 @@ def _validate_tmp_catalog(expected, file_path) -> None:
assert ConfiguredAirbyteCatalog.parse_file(file_path) == expected


def _create_tmp_file_validation(entrypoint, expected_config, expected_catalog, expected_state):
def _create_tmp_file_validation(entrypoint, expected_config, expected_catalog: Optional[Any] = None, expected_state: Optional[Any] = None):
def _validate_tmp_files(self):
_validate_tmp_json_file(expected_config, entrypoint.return_value.parse_args.call_args.args[0][2])
_validate_tmp_catalog(expected_catalog, entrypoint.return_value.parse_args.call_args.args[0][4])
_validate_tmp_json_file(expected_state, entrypoint.return_value.parse_args.call_args.args[0][6])
_validate_tmp_json_file(expected_config, entrypoint.parse_args.call_args.args[0][2])
if expected_catalog:
_validate_tmp_catalog(expected_catalog, entrypoint.parse_args.call_args.args[0][4])
if expected_state:
_validate_tmp_json_file(expected_state, entrypoint.parse_args.call_args.args[0][6])
return entrypoint.return_value.run.return_value

return _validate_tmp_files


class EntrypointWrapperTest(TestCase):
class EntrypointWrapperDiscoverTest(TestCase):
def setUp(self) -> None:
self._a_source = _a_mocked_source()

@patch("airbyte_cdk.test.entrypoint_wrapper.AirbyteEntrypoint")
def test_when_discover_then_ensure_parameters(self, entrypoint):
entrypoint.return_value.run.side_effect = _create_tmp_file_validation(entrypoint, _A_CONFIG)

discover(self._a_source, _A_CONFIG)

entrypoint.assert_called_once_with(self._a_source)
entrypoint.return_value.run.assert_called_once_with(entrypoint.parse_args.return_value)
assert entrypoint.parse_args.call_count == 1
assert entrypoint.parse_args.call_args.args[0][0] == "discover"
assert entrypoint.parse_args.call_args.args[0][1] == "--config"

@patch("airbyte_cdk.test.entrypoint_wrapper.AirbyteEntrypoint")
def test_when_discover_then_ensure_files_are_temporary(self, entrypoint):
discover(self._a_source, _A_CONFIG)

assert not os.path.exists(entrypoint.parse_args.call_args.args[0][2])

@patch("airbyte_cdk.test.entrypoint_wrapper.AirbyteEntrypoint")
def test_given_logging_during_discover_when_discover_then_output_has_logs(self, entrypoint):
def _do_some_logging(self):
logging.getLogger("any logger").info(_A_LOG_MESSAGE)
return entrypoint.return_value.run.return_value

entrypoint.return_value.run.side_effect = _do_some_logging

output = discover(self._a_source, _A_CONFIG)

assert len(output.logs) == 1
assert output.logs[0].log.message == _A_LOG_MESSAGE

@patch("airbyte_cdk.test.entrypoint_wrapper.AirbyteEntrypoint")
def test_given_record_when_discover_then_output_has_record(self, entrypoint):
entrypoint.return_value.run.return_value = _to_entrypoint_output([_A_CATALOG_MESSAGE])
output = discover(self._a_source, _A_CONFIG)
assert output.catalog == _A_CATALOG_MESSAGE

@patch("airbyte_cdk.test.entrypoint_wrapper.AirbyteEntrypoint")
def test_given_log_when_discover_then_output_has_log(self, entrypoint):
entrypoint.return_value.run.return_value = _to_entrypoint_output([_A_LOG])
output = discover(self._a_source, _A_CONFIG)
assert output.logs == [_A_LOG]

@patch("airbyte_cdk.test.entrypoint_wrapper.AirbyteEntrypoint")
def test_given_trace_message_when_discover_then_output_has_trace_messages(self, entrypoint):
entrypoint.return_value.run.return_value = _to_entrypoint_output([_AN_ANALYTIC_MESSAGE])
output = discover(self._a_source, _A_CONFIG)
assert output.analytics_messages == [_AN_ANALYTIC_MESSAGE]

@patch("airbyte_cdk.test.entrypoint_wrapper.print", create=True)
@patch("airbyte_cdk.test.entrypoint_wrapper.AirbyteEntrypoint")
def test_given_unexpected_exception_when_discover_then_print(self, entrypoint, print_mock):
entrypoint.return_value.run.side_effect = ValueError("This error should be printed")
discover(self._a_source, _A_CONFIG)
assert print_mock.call_count > 0

@patch("airbyte_cdk.test.entrypoint_wrapper.print", create=True)
@patch("airbyte_cdk.test.entrypoint_wrapper.AirbyteEntrypoint")
def test_given_expected_exception_when_discover_then_do_not_print(self, entrypoint, print_mock):
entrypoint.return_value.run.side_effect = ValueError("This error should not be printed")
discover(self._a_source, _A_CONFIG, expecting_exception=True)
assert print_mock.call_count == 0

@patch("airbyte_cdk.test.entrypoint_wrapper.AirbyteEntrypoint")
def test_given_uncaught_exception_when_read_then_output_has_error(self, entrypoint):
entrypoint.return_value.run.side_effect = ValueError("An error")
output = discover(self._a_source, _A_CONFIG)
assert output.errors


class EntrypointWrapperReadTest(TestCase):
def setUp(self) -> None:
self._a_source = _a_mocked_source()

Expand All @@ -131,19 +212,20 @@ def test_when_read_then_ensure_parameters(self, entrypoint):
read(self._a_source, _A_CONFIG, _A_CATALOG, _A_STATE)

entrypoint.assert_called_once_with(self._a_source)
entrypoint.return_value.run.assert_called_once_with(entrypoint.return_value.parse_args.return_value)
assert entrypoint.return_value.parse_args.call_count == 1
assert entrypoint.return_value.parse_args.call_args.args[0][0] == "read"
assert entrypoint.return_value.parse_args.call_args.args[0][1] == "--config"
assert entrypoint.return_value.parse_args.call_args.args[0][3] == "--catalog"
entrypoint.return_value.run.assert_called_once_with(entrypoint.parse_args.return_value)
assert entrypoint.parse_args.call_count == 1
assert entrypoint.parse_args.call_args.args[0][0] == "read"
assert entrypoint.parse_args.call_args.args[0][1] == "--config"
assert entrypoint.parse_args.call_args.args[0][3] == "--catalog"
assert entrypoint.parse_args.call_args.args[0][5] == "--state"

@patch("airbyte_cdk.test.entrypoint_wrapper.AirbyteEntrypoint")
def test_when_read_then_ensure_files_are_temporary(self, entrypoint):
read(self._a_source, _A_CONFIG, _A_CATALOG, _A_STATE)

assert not os.path.exists(entrypoint.return_value.parse_args.call_args.args[0][2])
assert not os.path.exists(entrypoint.return_value.parse_args.call_args.args[0][4])
assert not os.path.exists(entrypoint.return_value.parse_args.call_args.args[0][6])
assert not os.path.exists(entrypoint.parse_args.call_args.args[0][2])
assert not os.path.exists(entrypoint.parse_args.call_args.args[0][4])
assert not os.path.exists(entrypoint.parse_args.call_args.args[0][6])

@patch("airbyte_cdk.test.entrypoint_wrapper.AirbyteEntrypoint")
def test_given_logging_during_run_when_read_then_output_has_logs(self, entrypoint):
Expand Down Expand Up @@ -229,12 +311,12 @@ def test_given_unexpected_exception_when_read_then_print(self, entrypoint, print
@patch("airbyte_cdk.test.entrypoint_wrapper.print", create=True)
@patch("airbyte_cdk.test.entrypoint_wrapper.AirbyteEntrypoint")
def test_given_expected_exception_when_read_then_do_not_print(self, entrypoint, print_mock):
entrypoint.return_value.run.side_effect = ValueError("This error should be printed")
entrypoint.return_value.run.side_effect = ValueError("This error should not be printed")
read(self._a_source, _A_CONFIG, _A_CATALOG, _A_STATE, expecting_exception=True)
assert print_mock.call_count == 0

@patch("airbyte_cdk.test.entrypoint_wrapper.AirbyteEntrypoint")
def test_given_uncaught_exception_when_read_then_output_has_error(self, entrypoint):
entrypoint.return_value.run.side_effect = ValueError("This error should be printed")
entrypoint.return_value.run.side_effect = ValueError("An error")
output = read(self._a_source, _A_CONFIG, _A_CATALOG, _A_STATE)
assert output.errors

0 comments on commit 9b1d720

Please sign in to comment.