From 70d1942d8d70fa426aaf9377b781dba9432366d8 Mon Sep 17 00:00:00 2001 From: Joe Reuter Date: Fri, 9 Feb 2024 01:36:43 +0100 Subject: [PATCH] airbyte-lib: Fix telemetry for streaming (#34955) --- airbyte-lib/airbyte_lib/source.py | 22 ++++++---- .../test_source_test_fixture.py | 44 ++++++++++++------- 2 files changed, 42 insertions(+), 24 deletions(-) diff --git a/airbyte-lib/airbyte_lib/source.py b/airbyte-lib/airbyte_lib/source.py index bf68471f793e9b..47beb0ee0aa0f2 100644 --- a/airbyte-lib/airbyte_lib/source.py +++ b/airbyte-lib/airbyte_lib/source.py @@ -426,6 +426,8 @@ def _read_with_catalog( """ source_tracking_information = self.executor.get_telemetry_info() send_telemetry(source_tracking_information, cache_info, SyncState.STARTED) + sync_failed = False + self._processed_records = 0 # Reset the counter before we start try: with as_temp_files( [self._config, catalog.json(), json.dumps(state) if state else "[]"] @@ -449,14 +451,16 @@ def _read_with_catalog( send_telemetry( source_tracking_information, cache_info, SyncState.FAILED, self._processed_records ) + sync_failed = True raise finally: - send_telemetry( - source_tracking_information, - cache_info, - SyncState.SUCCEEDED, - self._processed_records, - ) + if not sync_failed: + send_telemetry( + source_tracking_information, + cache_info, + SyncState.SUCCEEDED, + self._processed_records, + ) def _add_to_logs(self, message: str) -> None: self._last_log_messages.append(message) @@ -479,11 +483,13 @@ def _execute(self, args: list[str]) -> Iterator[AirbyteMessage]: for line in self.executor.execute(args): try: message = AirbyteMessage.parse_raw(line) - yield message + if message.type is Type.RECORD: + self._processed_records += 1 if message.type == Type.LOG: self._add_to_logs(message.log.message) if message.type == Type.TRACE and message.trace.type == TraceType.ERROR: self._add_to_logs(message.trace.error.message) + yield message except Exception: self._add_to_logs(line) except Exception as e: @@ -500,8 +506,6 @@ def _tally_records( progress.reset(len(self._selected_stream_names or [])) for message in messages: - if message.type is Type.RECORD: - self._processed_records += 1 yield message progress.log_records_read(self._processed_records) diff --git a/airbyte-lib/tests/integration_tests/test_source_test_fixture.py b/airbyte-lib/tests/integration_tests/test_source_test_fixture.py index 8ebe0e53477e9d..2824f0db9c6f86 100644 --- a/airbyte-lib/tests/integration_tests/test_source_test_fixture.py +++ b/airbyte-lib/tests/integration_tests/test_source_test_fixture.py @@ -3,7 +3,8 @@ from collections.abc import Mapping import os import shutil -import subprocess +import itertools +from contextlib import nullcontext as does_not_raise from typing import Any from unittest.mock import Mock, call, patch import tempfile @@ -582,16 +583,30 @@ def test_airbyte_lib_version() -> None: @patch('airbyte_lib.telemetry.requests') @patch('airbyte_lib.telemetry.datetime') @pytest.mark.parametrize( - "raises, api_key, expected_state, expected_number_of_records, request_call_fails, extra_env, expected_flags", + "raises, api_key, expected_state, expected_number_of_records, request_call_fails, extra_env, expected_flags, cache_type, number_of_records_read", [ - pytest.param(True, "test_fail_during_sync", "failed", 1, False, {"CI": ""}, {"CI": False}, id="fail_during_sync"), - pytest.param(False, "test", "succeeded", 3, False, {"CI": ""}, {"CI": False}, id="succeed_during_sync"), - pytest.param(False, "test", "succeeded", 3, True, {"CI": ""}, {"CI": False}, id="fail_request_without_propagating"), - pytest.param(False, "test", "succeeded", 3, False, {"CI": ""}, {"CI": False}, id="falsy_ci_flag"), - pytest.param(False, "test", "succeeded", 3, False, {"CI": "true"}, {"CI": True}, id="truthy_ci_flag"), + pytest.param(pytest.raises(Exception), "test_fail_during_sync", "failed", 1, False, {"CI": ""}, {"CI": False}, "duckdb", None, id="fail_during_sync"), + pytest.param(does_not_raise(), "test", "succeeded", 3, False, {"CI": ""}, {"CI": False}, "duckdb", None, id="succeed_during_sync"), + pytest.param(does_not_raise(), "test", "succeeded", 3, True, {"CI": ""}, {"CI": False}, "duckdb", None,id="fail_request_without_propagating"), + pytest.param(does_not_raise(), "test", "succeeded", 3, False, {"CI": ""}, {"CI": False}, "duckdb", None,id="falsy_ci_flag"), + pytest.param(does_not_raise(), "test", "succeeded", 3, False, {"CI": "true"}, {"CI": True}, "duckdb", None,id="truthy_ci_flag"), + pytest.param(pytest.raises(Exception), "test_fail_during_sync", "failed", 1, False, {"CI": ""}, {"CI": False}, "streaming", 3, id="streaming_fail_during_sync"), + pytest.param(does_not_raise(), "test", "succeeded", 2, False, {"CI": ""}, {"CI": False}, "streaming", 2, id="streaming_succeed"), + pytest.param(does_not_raise(), "test", "succeeded", 1, False, {"CI": ""}, {"CI": False}, "streaming", 1, id="streaming_partial_read"), ], ) -def test_tracking(mock_datetime: Mock, mock_requests: Mock, raises: bool, api_key: str, expected_state: str, expected_number_of_records: int, request_call_fails: bool, extra_env: dict[str, str], expected_flags: dict[str, bool]): +def test_tracking( + mock_datetime: Mock, + mock_requests: Mock, + raises, api_key: str, + expected_state: str, + expected_number_of_records: int, + request_call_fails: bool, + extra_env: dict[str, str], + expected_flags: dict[str, bool], + cache_type: str, + number_of_records_read: int +): """ Test that the telemetry is sent when the sync is successful. This is done by mocking the requests.post method and checking that it is called with the right arguments. @@ -613,12 +628,11 @@ def test_tracking(mock_datetime: Mock, mock_requests: Mock, raises: bool, api_ke mock_post.side_effect = Exception("test exception") with patch.dict('os.environ', extra_env): - if raises: - with pytest.raises(Exception): + with raises: + if cache_type == "streaming": + list(itertools.islice(source.get_records("stream1"), number_of_records_read)) + else: source.read(cache) - else: - source.read(cache) - mock_post.assert_has_calls([ call("https://api.segment.io/v1/track", @@ -630,7 +644,7 @@ def test_tracking(mock_datetime: Mock, mock_requests: Mock, raises: bool, api_ke "version": get_version(), "source": {'name': 'source-test', 'version': '0.0.1', 'type': 'venv'}, "state": "started", - "cache": {"type": "duckdb"}, + "cache": {"type": cache_type}, "ip": "0.0.0.0", "flags": expected_flags }, @@ -648,7 +662,7 @@ def test_tracking(mock_datetime: Mock, mock_requests: Mock, raises: bool, api_ke "source": {'name': 'source-test', 'version': '0.0.1', 'type': 'venv'}, "state": expected_state, "number_of_records": expected_number_of_records, - "cache": {"type": "duckdb"}, + "cache": {"type": cache_type}, "ip": "0.0.0.0", "flags": expected_flags },