diff --git a/airbyte-integrations/connectors/source-salesforce/integration_tests/bulk_error_test.py b/airbyte-integrations/connectors/source-salesforce/integration_tests/bulk_error_test.py index d51d68d957dfb..598f1cdb513b0 100644 --- a/airbyte-integrations/connectors/source-salesforce/integration_tests/bulk_error_test.py +++ b/airbyte-integrations/connectors/source-salesforce/integration_tests/bulk_error_test.py @@ -17,6 +17,7 @@ HERE = Path(__file__).parent _ANY_CATALOG = ConfiguredAirbyteCatalog.parse_obj({"streams": []}) _ANY_CONFIG = {} +_ANY_STATE = {} @pytest.fixture(name="input_config") @@ -35,7 +36,7 @@ def get_stream(input_config: Mapping[str, Any], stream_name: str) -> Stream: stream_cls = type("a", (object,), {"name": stream_name}) configured_stream_cls = type("b", (object,), {"stream": stream_cls(), "sync_mode": "full_refresh"}) catalog_cls = type("c", (object,), {"streams": [configured_stream_cls()]}) - source = SourceSalesforce(_ANY_CATALOG, _ANY_CONFIG) + source = SourceSalesforce(_ANY_CATALOG, _ANY_CONFIG, _ANY_STATE) source.catalog = catalog_cls() return source.streams(input_config)[0] @@ -46,12 +47,12 @@ def get_any_real_stream(input_config: Mapping[str, Any]) -> Stream: def test_not_queryable_stream(caplog, input_config): stream = get_any_real_stream(input_config) - url = f"{stream.sf_api.instance_url}/services/data/{stream.sf_api.version}/jobs/query" + url = f"{stream._legacy_stream.sf_api.instance_url}/services/data/{stream._legacy_stream.sf_api.version}/jobs/query" # test non queryable BULK streams query = "Select Id, Subject from ActivityHistory" with caplog.at_level(logging.WARNING): - assert stream.create_stream_job(query, url) is None, "this stream should be skipped" + assert stream._legacy_stream.create_stream_job(query, url) is None, "this stream should be skipped" # check logs assert "is not queryable" in caplog.records[-1].message diff --git a/airbyte-integrations/connectors/source-salesforce/metadata.yaml b/airbyte-integrations/connectors/source-salesforce/metadata.yaml index 498448722b27c..9688083f23188 100644 --- a/airbyte-integrations/connectors/source-salesforce/metadata.yaml +++ b/airbyte-integrations/connectors/source-salesforce/metadata.yaml @@ -10,7 +10,7 @@ data: connectorSubtype: api connectorType: source definitionId: b117307c-14b6-41aa-9422-947e34922962 - dockerImageTag: 2.2.2 + dockerImageTag: 2.3.0 dockerRepository: airbyte/source-salesforce documentationUrl: https://docs.airbyte.com/integrations/sources/salesforce githubIssueLabel: source-salesforce diff --git a/airbyte-integrations/connectors/source-salesforce/setup.py b/airbyte-integrations/connectors/source-salesforce/setup.py index 4add132d7cb51..87df4dc169fc3 100644 --- a/airbyte-integrations/connectors/source-salesforce/setup.py +++ b/airbyte-integrations/connectors/source-salesforce/setup.py @@ -5,7 +5,7 @@ from setuptools import find_packages, setup -MAIN_REQUIREMENTS = ["airbyte-cdk~=0.55.2", "pandas"] +MAIN_REQUIREMENTS = ["airbyte-cdk~=0.58.10", "pandas"] TEST_REQUIREMENTS = ["freezegun", "pytest~=6.1", "pytest-mock~=3.6", "requests-mock~=1.9.3", "pytest-timeout"] diff --git a/airbyte-integrations/connectors/source-salesforce/source_salesforce/run.py b/airbyte-integrations/connectors/source-salesforce/source_salesforce/run.py index 7fe23dc8958c9..07c8c7ce83ab5 100644 --- a/airbyte-integrations/connectors/source-salesforce/source_salesforce/run.py +++ b/airbyte-integrations/connectors/source-salesforce/source_salesforce/run.py @@ -16,10 +16,12 @@ def _get_source(args: List[str]): 
catalog_path = AirbyteEntrypoint.extract_catalog(args) config_path = AirbyteEntrypoint.extract_config(args) + state_path = AirbyteEntrypoint.extract_state(args) try: return SourceSalesforce( SourceSalesforce.read_catalog(catalog_path) if catalog_path else None, SourceSalesforce.read_config(config_path) if config_path else None, + SourceSalesforce.read_state(state_path) if state_path else None, ) except Exception as error: print( diff --git a/airbyte-integrations/connectors/source-salesforce/source_salesforce/source.py b/airbyte-integrations/connectors/source-salesforce/source_salesforce/source.py index 30eea954dfe0f..975899f1d4914 100644 --- a/airbyte-integrations/connectors/source-salesforce/source_salesforce/source.py +++ b/airbyte-integrations/connectors/source-salesforce/source_salesforce/source.py @@ -6,6 +6,7 @@ from datetime import datetime from typing import Any, Iterator, List, Mapping, MutableMapping, Optional, Tuple, Union +import pendulum import requests from airbyte_cdk import AirbyteLogger from airbyte_cdk.logger import AirbyteLogFormatter @@ -14,9 +15,10 @@ from airbyte_cdk.sources.concurrent_source.concurrent_source_adapter import ConcurrentSourceAdapter from airbyte_cdk.sources.connector_state_manager import ConnectorStateManager from airbyte_cdk.sources.message import InMemoryMessageRepository +from airbyte_cdk.sources.source import TState from airbyte_cdk.sources.streams import Stream from airbyte_cdk.sources.streams.concurrent.adapters import StreamFacade -from airbyte_cdk.sources.streams.concurrent.cursor import NoopCursor +from airbyte_cdk.sources.streams.concurrent.cursor import ConcurrentCursor, CursorField, NoopCursor from airbyte_cdk.sources.streams.http.auth import TokenAuthenticator from airbyte_cdk.sources.utils.schema_helpers import InternalConfig from airbyte_cdk.utils.traced_exception import AirbyteTracedException @@ -50,7 +52,7 @@ class SourceSalesforce(ConcurrentSourceAdapter): message_repository = InMemoryMessageRepository(Level(AirbyteLogFormatter.level_mapping[logger.level])) - def __init__(self, catalog: Optional[ConfiguredAirbyteCatalog], config: Optional[Mapping[str, Any]], **kwargs): + def __init__(self, catalog: Optional[ConfiguredAirbyteCatalog], config: Optional[Mapping[str, Any]], state: Optional[TState], **kwargs): if config: concurrency_level = min(config.get("num_workers", _DEFAULT_CONCURRENCY), _MAX_CONCURRENCY) else: @@ -61,6 +63,7 @@ def __init__(self, catalog: Optional[ConfiguredAirbyteCatalog], config: Optional ) super().__init__(concurrent_source) self.catalog = catalog + self.state = state @staticmethod def _get_sf_object(config: Mapping[str, Any]) -> Salesforce: @@ -192,16 +195,39 @@ def streams(self, config: Mapping[str, Any]) -> List[Stream]: stream_objects = sf.get_validated_streams(config=config, catalog=self.catalog) streams = self.generate_streams(config, stream_objects, sf) streams.append(Describe(sf_api=sf, catalog=self.catalog)) - # TODO: incorporate state & ConcurrentCursor when we support incremental + state_manager = ConnectorStateManager(stream_instance_map={s.name: s for s in streams}, state=self.state) + configured_streams = [] + for stream in streams: sync_mode = self._get_sync_mode_from_catalog(stream) if sync_mode == SyncMode.full_refresh: - configured_streams.append(StreamFacade.create_from_stream(stream, self, logger, None, NoopCursor())) + cursor = NoopCursor() + state = None else: - configured_streams.append(stream) + cursor_field_key = stream.cursor_field or "" + if not isinstance(cursor_field_key, str): + raise 
AssertionError(f"A string cursor field key is required, but got {cursor_field_key}.") + cursor_field = CursorField(cursor_field_key) + legacy_state = state_manager.get_stream_state(stream.name, stream.namespace) + cursor = ConcurrentCursor( + stream.name, + stream.namespace, + legacy_state, + self.message_repository, + state_manager, + stream.state_converter, + cursor_field, + self._get_slice_boundary_fields(stream, state_manager), + config["start_date"], + ) + + configured_streams.append(StreamFacade.create_from_stream(stream, self, logger, cursor.state, cursor)) return configured_streams + def _get_slice_boundary_fields(self, stream: Stream, state_manager: ConnectorStateManager) -> Optional[Tuple[str, str]]: + return ("start_date", "end_date") + def _get_sync_mode_from_catalog(self, stream: Stream) -> Optional[SyncMode]: if self.catalog: for catalog_stream in self.catalog.streams: diff --git a/airbyte-integrations/connectors/source-salesforce/source_salesforce/streams.py b/airbyte-integrations/connectors/source-salesforce/source_salesforce/streams.py index 34c03d1caa948..9bd81a4d16dde 100644 --- a/airbyte-integrations/connectors/source-salesforce/source_salesforce/streams.py +++ b/airbyte-integrations/connectors/source-salesforce/source_salesforce/streams.py @@ -11,6 +11,7 @@ import uuid from abc import ABC from contextlib import closing +from datetime import datetime, timedelta from typing import Any, Callable, Iterable, List, Mapping, MutableMapping, Optional, Tuple, Type, Union import pandas as pd @@ -18,6 +19,7 @@ import requests # type: ignore[import] from airbyte_cdk.models import ConfiguredAirbyteCatalog, FailureType, SyncMode from airbyte_cdk.sources.streams.availability_strategy import AvailabilityStrategy +from airbyte_cdk.sources.streams.concurrent.state_converters.datetime_stream_state_converter import IsoMillisConcurrentStreamStateConverter from airbyte_cdk.sources.streams.core import Stream, StreamData from airbyte_cdk.sources.streams.http import HttpStream, HttpSubStream from airbyte_cdk.sources.utils.transform import TransformConfig, TypeTransformer @@ -37,9 +39,11 @@ csv.field_size_limit(CSV_FIELD_SIZE_LIMIT) DEFAULT_ENCODING = "utf-8" +LOOKBACK_SECONDS = 600 # based on https://trailhead.salesforce.com/trailblazer-community/feed/0D54V00007T48TASAZ class SalesforceStream(HttpStream, ABC): + state_converter = IsoMillisConcurrentStreamStateConverter() page_size = 2000 transformer = TypeTransformer(TransformConfig.DefaultSchemaNormalization) encoding = DEFAULT_ENCODING @@ -108,6 +112,16 @@ def get_error_display_message(self, exception: BaseException) -> Optional[str]: return f"After {self.max_retries} retries the connector has failed with a network error. It looks like Salesforce API experienced temporary instability, please try again later." 
return super().get_error_display_message(exception)
 
+    def get_start_date_from_state(self, stream_state: Optional[Mapping[str, Any]] = None) -> datetime:
+        if self.state_converter.is_state_message_compatible(stream_state):
+            # stream_state is in the concurrent format
+            if stream_state.get("slices", []):
+                return stream_state["slices"][0]["end"]
+        elif stream_state:
+            # stream_state has not been converted to the concurrent format; this is not expected
+            return pendulum.parse(stream_state.get(self.cursor_field), tz="UTC")
+        return pendulum.parse(self.start_date, tz="UTC")
+
 
 class PropertyChunk:
     """
@@ -127,6 +141,8 @@ def __init__(self, properties: Mapping[str, Any]):
 
 
 class RestSalesforceStream(SalesforceStream):
+    state_converter = IsoMillisConcurrentStreamStateConverter()
+
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         assert self.primary_key or not self.too_many_properties
@@ -302,6 +318,7 @@ def _fetch_next_page_for_chunk(
 
 
 class BatchedSubStream(HttpSubStream):
+    state_converter = IsoMillisConcurrentStreamStateConverter()
     SLICE_BATCH_SIZE = 200
 
     def stream_slices(
@@ -684,7 +701,8 @@ def stream_slices(
     ) -> Iterable[Optional[Mapping[str, Any]]]:
         start, end = (None, None)
         now = pendulum.now(tz="UTC")
-        initial_date = pendulum.parse((stream_state or {}).get(self.cursor_field, self.start_date), tz="UTC")
+        assert LOOKBACK_SECONDS is not None and LOOKBACK_SECONDS >= 0, "LOOKBACK_SECONDS must be a non-negative number of seconds"
+        initial_date = self.get_start_date_from_state(stream_state) - timedelta(seconds=LOOKBACK_SECONDS)
 
         slice_number = 1
         while not end == now:
@@ -768,6 +786,7 @@ def request_params(
 
 
 class Describe(Stream):
+    state_converter = IsoMillisConcurrentStreamStateConverter()
     """
     Stream of sObjects' (Salesforce Objects) describe:
     https://developer.salesforce.com/docs/atlas.en-us.api_rest.meta/api_rest/resources_sobject_describe.htm
diff --git a/airbyte-integrations/connectors/source-salesforce/unit_tests/api_test.py b/airbyte-integrations/connectors/source-salesforce/unit_tests/api_test.py
index 8f87e2bd58cdd..a87f73862dd5a 100644
--- a/airbyte-integrations/connectors/source-salesforce/unit_tests/api_test.py
+++ b/airbyte-integrations/connectors/source-salesforce/unit_tests/api_test.py
@@ -7,9 +7,9 @@
 import io
 import logging
 import re
-from datetime import datetime
+from datetime import datetime, timedelta
 from typing import List
-from unittest.mock import Mock
+from unittest.mock import Mock, patch
 
 import freezegun
 import pendulum
@@ -37,6 +37,7 @@
 
 _ANY_CATALOG = ConfiguredAirbyteCatalog.parse_obj({"streams": []})
 _ANY_CONFIG = {}
+_ANY_STATE = None
 
 
 @pytest.mark.parametrize(
@@ -65,7 +66,7 @@ def test_login_authentication_error_handler(
     stream_config, requests_mock, login_status_code, login_json_resp, expected_error_msg, is_config_error
 ):
-    source = SourceSalesforce(_ANY_CATALOG, _ANY_CONFIG)
+    source = SourceSalesforce(_ANY_CATALOG, _ANY_CONFIG, _ANY_STATE)
     logger = logging.getLogger("airbyte")
     requests_mock.register_uri(
         "POST", "https://login.salesforce.com/services/oauth2/token", json=login_json_resp, status_code=login_status_code
@@ -345,7 +346,7 @@ def test_encoding_symbols(stream_config, stream_api, chunk_size, content_type_he
 def test_check_connection_rate_limit(
     stream_config, login_status_code, login_json_resp, discovery_status_code, discovery_resp_json, expected_error_msg
 ):
-    source = SourceSalesforce(_ANY_CATALOG, _ANY_CONFIG)
+    source = SourceSalesforce(_ANY_CATALOG, _ANY_CONFIG, _ANY_STATE)
     logger = logging.getLogger("airbyte")
     with 
requests_mock.Mocker() as m: @@ -382,7 +383,7 @@ def test_rate_limit_bulk(stream_config, stream_api, bulk_catalog, state): stream_1.page_size = 6 stream_1.state_checkpoint_interval = 5 - source = SourceSalesforce(_ANY_CATALOG, _ANY_CONFIG) + source = SourceSalesforce(_ANY_CATALOG, _ANY_CONFIG, _ANY_STATE) source.streams = Mock() source.streams.return_value = streams logger = logging.getLogger("airbyte") @@ -438,7 +439,7 @@ def test_rate_limit_rest(stream_config, stream_api, rest_catalog, state): stream_1.state_checkpoint_interval = 3 configure_request_params_mock(stream_1, stream_2) - source = SourceSalesforce(_ANY_CATALOG, _ANY_CONFIG) + source = SourceSalesforce(_ANY_CATALOG, _ANY_CONFIG, _ANY_STATE) source.streams = Mock() source.streams.return_value = [stream_1, stream_2] @@ -623,7 +624,7 @@ def test_forwarding_sobject_options(stream_config, stream_names, catalog_stream_ ], }, ) - source = SourceSalesforce(_ANY_CATALOG, _ANY_CONFIG) + source = SourceSalesforce(_ANY_CATALOG, _ANY_CONFIG, _ANY_STATE) source.catalog = catalog streams = source.streams(config=stream_config) expected_names = catalog_stream_names if catalog else stream_names @@ -638,28 +639,6 @@ def test_forwarding_sobject_options(stream_config, stream_names, catalog_stream_ return -@pytest.mark.parametrize( - "stream_names,catalog_stream_names,", - ( - ( - ["stream_1", "stream_2", "Describe"], - None, - ), - ( - ["stream_1", "stream_2"], - ["stream_1", "stream_2", "Describe"], - ), - ( - ["stream_1", "stream_2", "stream_3", "Describe"], - ["stream_1", "Describe"], - ), - ), -) -def test_unspecified_and_incremental_streams_are_not_concurrent(stream_config, stream_names, catalog_stream_names) -> None: - for stream in _get_streams(stream_config, stream_names, catalog_stream_names, SyncMode.incremental): - assert isinstance(stream, (SalesforceStream, Describe)) - - @pytest.mark.parametrize( "stream_names,catalog_stream_names,", ( @@ -723,7 +702,7 @@ def _get_streams(stream_config, stream_names, catalog_stream_names, sync_type) - ], }, ) - source = SourceSalesforce(_ANY_CATALOG, _ANY_CONFIG) + source = SourceSalesforce(_ANY_CATALOG, _ANY_CONFIG, _ANY_STATE) source.catalog = catalog return source.streams(config=stream_config) @@ -886,21 +865,33 @@ def test_bulk_stream_error_on_wait_for_job(requests_mock, stream_config, stream_ @freezegun.freeze_time("2023-01-01") -def test_bulk_stream_slices(stream_config_date_format, stream_api): +@pytest.mark.parametrize( + "lookback, expect_error", + [(None, True), (0, False), (10, False), (-1, True)], + ids=["lookback-is-none", "lookback-is-0", "lookback-is-valid", "lookback-is-negative"], +) +def test_bulk_stream_slices(stream_config_date_format, stream_api, lookback, expect_error): stream: BulkIncrementalSalesforceStream = generate_stream("FakeBulkStream", stream_config_date_format, stream_api) - stream_slices = list(stream.stream_slices(sync_mode=SyncMode.full_refresh)) - expected_slices = [] - today = pendulum.today(tz="UTC") - start_date = pendulum.parse(stream.start_date, tz="UTC") - while start_date < today: - expected_slices.append( - { - "start_date": start_date.isoformat(timespec="milliseconds"), - "end_date": min(today, start_date.add(days=stream.STREAM_SLICE_STEP)).isoformat(timespec="milliseconds"), - } - ) - start_date = start_date.add(days=stream.STREAM_SLICE_STEP) - assert expected_slices == stream_slices + with patch("source_salesforce.streams.LOOKBACK_SECONDS", lookback): + if expect_error: + with pytest.raises(AssertionError): + 
list(stream.stream_slices(sync_mode=SyncMode.full_refresh)) + else: + stream_slices = list(stream.stream_slices(sync_mode=SyncMode.full_refresh)) + + expected_slices = [] + today = pendulum.today(tz="UTC") + start_date = pendulum.parse(stream.start_date, tz="UTC") - timedelta(seconds=lookback) + while start_date < today: + expected_slices.append( + { + "start_date": start_date.isoformat(timespec="milliseconds"), + "end_date": min(today, start_date.add(days=stream.STREAM_SLICE_STEP)).isoformat(timespec="milliseconds"), + } + ) + start_date = start_date.add(days=stream.STREAM_SLICE_STEP) + assert expected_slices == stream_slices + @freezegun.freeze_time("2023-04-01") def test_bulk_stream_request_params_states(stream_config_date_format, stream_api, bulk_catalog, requests_mock): @@ -908,7 +899,7 @@ def test_bulk_stream_request_params_states(stream_config_date_format, stream_api stream_config_date_format.update({"start_date": "2023-01-01"}) stream: BulkIncrementalSalesforceStream = generate_stream("Account", stream_config_date_format, stream_api) - source = SourceSalesforce(_ANY_CATALOG, _ANY_CONFIG) + source = SourceSalesforce(_ANY_CATALOG, _ANY_CONFIG, _ANY_STATE) source.streams = Mock() source.streams.return_value = [stream] @@ -938,7 +929,8 @@ def test_bulk_stream_request_params_states(stream_config_date_format, stream_api logger = logging.getLogger("airbyte") state = {"Account": {"LastModifiedDate": "2023-01-01T10:10:10.000Z"}} bulk_catalog.streams.pop(1) - result = [i for i in source.read(logger=logger, config=stream_config_date_format, catalog=bulk_catalog, state=state)] + with patch("source_salesforce.streams.LOOKBACK_SECONDS", 0): + result = [i for i in source.read(logger=logger, config=stream_config_date_format, catalog=bulk_catalog, state=state)] actual_state_values = [item.state.data.get("Account").get(stream.cursor_field) for item in result if item.type == Type.STATE] # assert request params diff --git a/docs/integrations/sources/salesforce.md b/docs/integrations/sources/salesforce.md index 282fc09a901cc..af1262e58411e 100644 --- a/docs/integrations/sources/salesforce.md +++ b/docs/integrations/sources/salesforce.md @@ -193,7 +193,8 @@ Now that you have set up the Salesforce source connector, check out the followin | Version | Date | Pull Request | Subject | |:--------|:-----------|:---------------------------------------------------------|:-------------------------------------------------------------------------------------------------------------------------------------| -| 2.2.2 | 2024-01-04 | [33936](https://github.com/airbytehq/airbyte/pull/33936) | Prepare for airbyte-lib | +| 2.3.0 | 2023-12-15 | [33522](https://github.com/airbytehq/airbyte/pull/33522) | Sync streams concurrently in all sync modes | +| 2.2.2 | 2024-01-04 | [33936](https://github.com/airbytehq/airbyte/pull/33936) | Prepare for airbyte-lib | | 2.2.1 | 2023-12-12 | [33342](https://github.com/airbytehq/airbyte/pull/33342) | Added new ContentDocumentLink stream | | 2.2.0 | 2023-12-12 | [33350](https://github.com/airbytehq/airbyte/pull/33350) | Sync streams concurrently on full refresh | | 2.1.6 | 2023-11-28 | [32535](https://github.com/airbytehq/airbyte/pull/32535) | Run full refresh syncs concurrently |
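
The constructor change above threads a third positional argument through every call site: the source now receives the connector state alongside the catalog and config, hands it to a `ConnectorStateManager`, and builds a `ConcurrentCursor` per incremental stream. Below is a minimal wiring sketch, assuming only what this diff shows; the legacy per-stream state dict mirrors the one used in `test_bulk_stream_request_params_states`, and an empty config is enough to construct (but not run) the source, as the tests' `_ANY_CONFIG` does.

```python
from airbyte_cdk.models import ConfiguredAirbyteCatalog
from source_salesforce.source import SourceSalesforce

catalog = ConfiguredAirbyteCatalog.parse_obj({"streams": []})
config = {}  # a real config carries Salesforce credentials and a start_date
# Legacy-format state, as exercised by the unit tests above; passing None
# (the tests' _ANY_STATE) means a fresh sync with no saved cursor.
state = {"Account": {"LastModifiedDate": "2023-01-01T10:10:10.000Z"}}

source = SourceSalesforce(catalog, config, state)  # state is now a required positional argument
```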
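`get_start_date_from_state` now has to discriminate between two state shapes. A hedged illustration of both, inferred from the branches of that method; the slice values here are assumptions for display, not taken from this diff:

```python
import pendulum

# Concurrent-format state: the state converter keeps merged date-range slices,
# and the first branch resumes from the "end" of the first slice.
concurrent_state = {
    "slices": [
        {"start": pendulum.datetime(2023, 1, 1), "end": pendulum.datetime(2023, 3, 1)},
    ],
}

# Legacy-format state: a plain mapping from the cursor field to its high-water mark.
legacy_state = {"LastModifiedDate": "2023-01-01T10:10:10.000Z"}

resume_from = concurrent_state["slices"][0]["end"]  # what the first branch returns
print(resume_from)
```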
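The `LOOKBACK_SECONDS` constant introduced in streams.py rewinds the computed start of the first slice by ten minutes, so records whose `LastModifiedDate` lands just before a saved cursor value are re-read on the next sync. A sketch of the arithmetic under stated assumptions: `STREAM_SLICE_STEP` is an illustrative stand-in for the stream's own class attribute, and the slice dict copies the `start_date`/`end_date` boundary fields returned by `_get_slice_boundary_fields`.

```python
from datetime import timedelta

import pendulum

LOOKBACK_SECONDS = 600  # mirrors source_salesforce.streams.LOOKBACK_SECONDS
STREAM_SLICE_STEP = 30  # illustrative; the real value lives on the stream class


def build_slices(cursor_value: str):
    """Yield date-window slices from (cursor - lookback) up to now."""
    now = pendulum.now(tz="UTC")
    start = pendulum.parse(cursor_value, tz="UTC") - timedelta(seconds=LOOKBACK_SECONDS)
    while start < now:
        end = min(now, start.add(days=STREAM_SLICE_STEP))
        yield {
            "start_date": start.isoformat(timespec="milliseconds"),
            "end_date": end.isoformat(timespec="milliseconds"),
        }
        start = start.add(days=STREAM_SLICE_STEP)


# The first slice starts 600 seconds before the saved cursor:
print(next(build_slices("2023-01-01T10:10:10.000Z")))
```

The overlap means a handful of records near the cursor are emitted twice; that is the usual at-least-once trade-off for catching late-arriving updates.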