Source Salesforce: run all syncs concurrently
clnoll committed Jan 30, 2024
1 parent acd26ac commit 87b7566
Showing 8 changed files with 100 additions and 58 deletions.
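In brief: this commit moves every stream — incremental included — onto the CDK's concurrent source. Connector state is now threaded through `SourceSalesforce` so incremental streams are wrapped in a `StreamFacade` with a `ConcurrentCursor` (full-refresh streams keep a `NoopCursor`), the CDK dependency is bumped from 0.55.2 to 0.59.0, and a 600-second lookback window guards against records that Salesforce commits late.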
@@ -17,6 +17,7 @@
HERE = Path(__file__).parent
_ANY_CATALOG = ConfiguredAirbyteCatalog.parse_obj({"streams": []})
_ANY_CONFIG = {}
_ANY_STATE = {}


@pytest.fixture(name="input_config")
@@ -35,7 +36,7 @@ def get_stream(input_config: Mapping[str, Any], stream_name: str) -> Stream:
stream_cls = type("a", (object,), {"name": stream_name})
configured_stream_cls = type("b", (object,), {"stream": stream_cls(), "sync_mode": "full_refresh"})
catalog_cls = type("c", (object,), {"streams": [configured_stream_cls()]})
source = SourceSalesforce(_ANY_CATALOG, _ANY_CONFIG)
source = SourceSalesforce(_ANY_CATALOG, _ANY_CONFIG, _ANY_STATE)
source.catalog = catalog_cls()
return source.streams(input_config)[0]

@@ -46,12 +47,12 @@ def get_any_real_stream(input_config: Mapping[str, Any]) -> Stream:

def test_not_queryable_stream(caplog, input_config):
stream = get_any_real_stream(input_config)
url = f"{stream.sf_api.instance_url}/services/data/{stream.sf_api.version}/jobs/query"
url = f"{stream._legacy_stream.sf_api.instance_url}/services/data/{stream._legacy_stream.sf_api.version}/jobs/query"

# test non-queryable BULK streams
query = "Select Id, Subject from ActivityHistory"
with caplog.at_level(logging.WARNING):
assert stream.create_stream_job(query, url) is None, "this stream should be skipped"
assert stream._legacy_stream.create_stream_job(query, url) is None, "this stream should be skipped"

# check logs
assert "is not queryable" in caplog.records[-1].message
@@ -10,7 +10,7 @@ data:
connectorSubtype: api
connectorType: source
definitionId: b117307c-14b6-41aa-9422-947e34922962
dockerImageTag: 2.2.2
dockerImageTag: 2.3.0
dockerRepository: airbyte/source-salesforce
documentationUrl: https://docs.airbyte.com/integrations/sources/salesforce
githubIssueLabel: source-salesforce
2 changes: 1 addition & 1 deletion airbyte-integrations/connectors/source-salesforce/setup.py
@@ -5,7 +5,7 @@

from setuptools import find_packages, setup

MAIN_REQUIREMENTS = ["airbyte-cdk~=0.55.2", "pandas"]
MAIN_REQUIREMENTS = ["airbyte-cdk~=0.59.0", "pandas"]

TEST_REQUIREMENTS = ["freezegun", "pytest~=6.1", "pytest-mock~=3.6", "requests-mock~=1.9.3", "pytest-timeout"]

@@ -16,10 +16,12 @@
def _get_source(args: List[str]):
catalog_path = AirbyteEntrypoint.extract_catalog(args)
config_path = AirbyteEntrypoint.extract_config(args)
state_path = AirbyteEntrypoint.extract_state(args)
try:
return SourceSalesforce(
SourceSalesforce.read_catalog(catalog_path) if catalog_path else None,
SourceSalesforce.read_config(config_path) if config_path else None,
SourceSalesforce.read_state(state_path) if state_path else None,
)
except Exception as error:
print(
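For orientation, a minimal sketch of the new entrypoint flow — the `--state` argument is now parsed alongside config and catalog and fed into the constructor. File names here are illustrative, not part of the commit:

```python
# Hypothetical invocation; the paths are placeholders.
args = ["read", "--config", "config.json", "--catalog", "catalog.json", "--state", "state.json"]

# _get_source (above) now extracts all three paths and calls
# SourceSalesforce(read_catalog(...), read_config(...), read_state(...)).
source = _get_source(args)
```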
@@ -6,6 +6,7 @@
from datetime import datetime
from typing import Any, Iterator, List, Mapping, MutableMapping, Optional, Tuple, Union

import pendulum
import requests
from airbyte_cdk import AirbyteLogger
from airbyte_cdk.logger import AirbyteLogFormatter
@@ -14,9 +15,10 @@
from airbyte_cdk.sources.concurrent_source.concurrent_source_adapter import ConcurrentSourceAdapter
from airbyte_cdk.sources.connector_state_manager import ConnectorStateManager
from airbyte_cdk.sources.message import InMemoryMessageRepository
from airbyte_cdk.sources.source import TState
from airbyte_cdk.sources.streams import Stream
from airbyte_cdk.sources.streams.concurrent.adapters import StreamFacade
from airbyte_cdk.sources.streams.concurrent.cursor import NoopCursor
from airbyte_cdk.sources.streams.concurrent.cursor import ConcurrentCursor, CursorField, NoopCursor
from airbyte_cdk.sources.streams.http.auth import TokenAuthenticator
from airbyte_cdk.sources.utils.schema_helpers import InternalConfig
from airbyte_cdk.utils.traced_exception import AirbyteTracedException
@@ -50,7 +52,7 @@ class SourceSalesforce(ConcurrentSourceAdapter):

message_repository = InMemoryMessageRepository(Level(AirbyteLogFormatter.level_mapping[logger.level]))

def __init__(self, catalog: Optional[ConfiguredAirbyteCatalog], config: Optional[Mapping[str, Any]], **kwargs):
def __init__(self, catalog: Optional[ConfiguredAirbyteCatalog], config: Optional[Mapping[str, Any]], state: Optional[TState], **kwargs):
if config:
concurrency_level = min(config.get("num_workers", _DEFAULT_CONCURRENCY), _MAX_CONCURRENCY)
else:
@@ -61,6 +63,7 @@ def __init__(self, catalog: Optional[ConfiguredAirbyteCatalog], config: Optional
)
super().__init__(concurrent_source)
self.catalog = catalog
self.state = state

@staticmethod
def _get_sf_object(config: Mapping[str, Any]) -> Salesforce:
@@ -192,16 +195,40 @@ def streams(self, config: Mapping[str, Any]) -> List[Stream]:
stream_objects = sf.get_validated_streams(config=config, catalog=self.catalog)
streams = self.generate_streams(config, stream_objects, sf)
streams.append(Describe(sf_api=sf, catalog=self.catalog))
# TODO: incorporate state & ConcurrentCursor when we support incremental
state_manager = ConnectorStateManager(stream_instance_map={s.name: s for s in streams}, state=self.state)

configured_streams = []

for stream in streams:
sync_mode = self._get_sync_mode_from_catalog(stream)
if sync_mode == SyncMode.full_refresh:
configured_streams.append(StreamFacade.create_from_stream(stream, self, logger, None, NoopCursor()))
cursor = NoopCursor()
state = None
else:
configured_streams.append(stream)
cursor_field_key = stream.cursor_field or ""
if not isinstance(cursor_field_key, str):
raise AssertionError(f"A string cursor field key is required, but got {cursor_field_key}.")
cursor_field = CursorField(cursor_field_key)
legacy_state = state_manager.get_stream_state(stream.name, stream.namespace)
cursor = ConcurrentCursor(
stream.name,
stream.namespace,
legacy_state,
self.message_repository,
state_manager,
stream.state_converter,
cursor_field,
self._get_slice_boundary_fields(stream, state_manager),
config["start_date"],
)
state = cursor.state

configured_streams.append(StreamFacade.create_from_stream(stream, self, logger, state, cursor))
return configured_streams

def _get_slice_boundary_fields(self, stream: Stream, state_manager: ConnectorStateManager) -> Optional[Tuple[str, str]]:
return ("start_date", "end_date")

def _get_sync_mode_from_catalog(self, stream: Stream) -> Optional[SyncMode]:
if self.catalog:
for catalog_stream in self.catalog.streams:
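As a point of reference, `_get_slice_boundary_fields` names the keys that each generated stream slice carries; the slice dictionaries built by `stream_slices` (and asserted in the tests further down) look like this, with illustrative values:

```python
# One concurrent slice: boundaries keyed by the fields returned
# from _get_slice_boundary_fields ("start_date", "end_date").
slice_example = {
    "start_date": "2023-01-01T00:00:00.000+00:00",
    "end_date": "2023-01-31T00:00:00.000+00:00",
}
```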
@@ -11,13 +11,15 @@
import uuid
from abc import ABC
from contextlib import closing
from datetime import datetime, timedelta
from typing import Any, Callable, Iterable, List, Mapping, MutableMapping, Optional, Tuple, Type, Union

import pandas as pd
import pendulum
import requests # type: ignore[import]
from airbyte_cdk.models import ConfiguredAirbyteCatalog, FailureType, SyncMode
from airbyte_cdk.sources.streams.availability_strategy import AvailabilityStrategy
from airbyte_cdk.sources.streams.concurrent.state_converters.datetime_stream_state_converter import IsoMillisConcurrentStreamStateConverter
from airbyte_cdk.sources.streams.core import Stream, StreamData
from airbyte_cdk.sources.streams.http import HttpStream, HttpSubStream
from airbyte_cdk.sources.utils.transform import TransformConfig, TypeTransformer
@@ -37,9 +39,11 @@
csv.field_size_limit(CSV_FIELD_SIZE_LIMIT)

DEFAULT_ENCODING = "utf-8"
LOOKBACK_SECONDS = 600 # based on https://trailhead.salesforce.com/trailblazer-community/feed/0D54V00007T48TASAZ


class SalesforceStream(HttpStream, ABC):
state_converter = IsoMillisConcurrentStreamStateConverter()
page_size = 2000
transformer = TypeTransformer(TransformConfig.DefaultSchemaNormalization)
encoding = DEFAULT_ENCODING
@@ -108,6 +112,16 @@ def get_error_display_message(self, exception: BaseException) -> Optional[str]:
return f"After {self.max_retries} retries the connector has failed with a network error. It looks like Salesforce API experienced temporary instability, please try again later."
return super().get_error_display_message(exception)

def get_start_date_from_state(self, stream_state: Optional[Mapping[str, Any]] = None) -> datetime:
if self.state_converter.is_state_message_compatible(stream_state):
# stream_state is in the concurrent format
if stream_state.get("slices", []):
return stream_state["slices"][0]["end"]
elif stream_state and not self.state_converter.is_state_message_compatible(stream_state):
# stream_state has not been converted to the concurrent format; this is not expected
return pendulum.parse(stream_state.get(self.cursor_field), tz="UTC")
return pendulum.parse(self.start_date, tz="UTC")

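`get_start_date_from_state` has to cope with two state shapes: the legacy single-cursor format and the concurrent format the new cursor persists. A sketch of both, with illustrative values (the concurrent shape is an assumption inferred from how `slices[0]["end"]` is read above):

```python
# Legacy per-stream state: a single saved cursor value.
legacy_state = {"LastModifiedDate": "2023-01-01T10:10:10.000Z"}

# Concurrent-format state (assumed shape): completed time ranges as slices;
# the next sync resumes from the end of the first slice.
concurrent_state = {
    "slices": [
        {"start": "2020-01-01T00:00:00.000Z", "end": "2023-01-01T10:10:10.000Z"},
    ],
}
```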

class PropertyChunk:
"""
@@ -127,6 +141,8 @@ def __init__(self, properties: Mapping[str, Any]):


class RestSalesforceStream(SalesforceStream):
state_converter = IsoMillisConcurrentStreamStateConverter()

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
assert self.primary_key or not self.too_many_properties
@@ -302,6 +318,7 @@ def _fetch_next_page_for_chunk(


class BatchedSubStream(HttpSubStream):
state_converter = IsoMillisConcurrentStreamStateConverter()
SLICE_BATCH_SIZE = 200

def stream_slices(
@@ -684,7 +701,8 @@ def stream_slices(
) -> Iterable[Optional[Mapping[str, Any]]]:
start, end = (None, None)
now = pendulum.now(tz="UTC")
initial_date = pendulum.parse((stream_state or {}).get(self.cursor_field, self.start_date), tz="UTC")
assert LOOKBACK_SECONDS is not None and LOOKBACK_SECONDS >= 0
initial_date = self.get_start_date_from_state(stream_state) - timedelta(seconds=LOOKBACK_SECONDS)

slice_number = 1
while not end == now:
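The lookback applied above trades a little re-reading for correctness: the next sync starts `LOOKBACK_SECONDS` before the saved cursor so records committed late on the Salesforce side (see the Trailblazer thread linked at the constant) are not skipped. A quick arithmetic check, assuming pendulum is available:

```python
from datetime import timedelta

import pendulum

# A cursor saved at 00:10:00Z with the 600-second lookback re-reads from 00:00:00Z.
cursor = pendulum.parse("2023-04-01T00:10:00Z")
assert cursor - timedelta(seconds=600) == pendulum.parse("2023-04-01T00:00:00Z")
```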
@@ -768,6 +786,7 @@ def request_params(


class Describe(Stream):
state_converter = IsoMillisConcurrentStreamStateConverter()
"""
Stream of sObjects' (Salesforce Objects) describe:
https://developer.salesforce.com/docs/atlas.en-us.api_rest.meta/api_rest/resources_sobject_describe.htm
@@ -7,9 +7,9 @@
import io
import logging
import re
from datetime import datetime
from datetime import datetime, timedelta
from typing import List
from unittest.mock import Mock
from unittest.mock import Mock, patch

import freezegun
import pendulum
@@ -37,6 +37,7 @@

_ANY_CATALOG = ConfiguredAirbyteCatalog.parse_obj({"streams": []})
_ANY_CONFIG = {}
_ANY_STATE = None


@pytest.mark.parametrize(
@@ -65,7 +66,7 @@
def test_login_authentication_error_handler(
stream_config, requests_mock, login_status_code, login_json_resp, expected_error_msg, is_config_error
):
source = SourceSalesforce(_ANY_CATALOG, _ANY_CONFIG)
source = SourceSalesforce(_ANY_CATALOG, _ANY_CONFIG, _ANY_STATE)
logger = logging.getLogger("airbyte")
requests_mock.register_uri(
"POST", "https://login.salesforce.com/services/oauth2/token", json=login_json_resp, status_code=login_status_code
@@ -345,7 +346,7 @@ def test_encoding_symbols(stream_config, stream_api, chunk_size, content_type_he
def test_check_connection_rate_limit(
stream_config, login_status_code, login_json_resp, discovery_status_code, discovery_resp_json, expected_error_msg
):
source = SourceSalesforce(_ANY_CATALOG, _ANY_CONFIG)
source = SourceSalesforce(_ANY_CATALOG, _ANY_CONFIG, _ANY_STATE)
logger = logging.getLogger("airbyte")

with requests_mock.Mocker() as m:
@@ -382,7 +383,7 @@ def test_rate_limit_bulk(stream_config, stream_api, bulk_catalog, state):
stream_1.page_size = 6
stream_1.state_checkpoint_interval = 5

source = SourceSalesforce(_ANY_CATALOG, _ANY_CONFIG)
source = SourceSalesforce(_ANY_CATALOG, _ANY_CONFIG, _ANY_STATE)
source.streams = Mock()
source.streams.return_value = streams
logger = logging.getLogger("airbyte")
@@ -438,7 +439,7 @@ def test_rate_limit_rest(stream_config, stream_api, rest_catalog, state):
stream_1.state_checkpoint_interval = 3
configure_request_params_mock(stream_1, stream_2)

source = SourceSalesforce(_ANY_CATALOG, _ANY_CONFIG)
source = SourceSalesforce(_ANY_CATALOG, _ANY_CONFIG, _ANY_STATE)
source.streams = Mock()
source.streams.return_value = [stream_1, stream_2]

@@ -623,7 +624,7 @@ def test_forwarding_sobject_options(stream_config, stream_names, catalog_stream_
],
},
)
source = SourceSalesforce(_ANY_CATALOG, _ANY_CONFIG)
source = SourceSalesforce(_ANY_CATALOG, _ANY_CONFIG, _ANY_STATE)
source.catalog = catalog
streams = source.streams(config=stream_config)
expected_names = catalog_stream_names if catalog else stream_names
@@ -638,28 +639,6 @@ def test_forwarding_sobject_options(stream_config, stream_names, catalog_stream_
return


@pytest.mark.parametrize(
"stream_names,catalog_stream_names,",
(
(
["stream_1", "stream_2", "Describe"],
None,
),
(
["stream_1", "stream_2"],
["stream_1", "stream_2", "Describe"],
),
(
["stream_1", "stream_2", "stream_3", "Describe"],
["stream_1", "Describe"],
),
),
)
def test_unspecified_and_incremental_streams_are_not_concurrent(stream_config, stream_names, catalog_stream_names) -> None:
for stream in _get_streams(stream_config, stream_names, catalog_stream_names, SyncMode.incremental):
assert isinstance(stream, (SalesforceStream, Describe))


@pytest.mark.parametrize(
"stream_names,catalog_stream_names,",
(
@@ -723,7 +702,7 @@ def _get_streams(stream_config, stream_names, catalog_stream_names, sync_type) -
],
},
)
source = SourceSalesforce(_ANY_CATALOG, _ANY_CONFIG)
source = SourceSalesforce(_ANY_CATALOG, _ANY_CONFIG, _ANY_STATE)
source.catalog = catalog
return source.streams(config=stream_config)

@@ -886,29 +865,41 @@ def test_bulk_stream_error_on_wait_for_job(requests_mock, stream_config, stream_


@freezegun.freeze_time("2023-01-01")
def test_bulk_stream_slices(stream_config_date_format, stream_api):
@pytest.mark.parametrize(
"lookback, expect_error",
[(None, True), (0, False), (10, False), (-1, True)],
ids=["lookback-is-none", "lookback-is-0", "lookback-is-valid", "lookback-is-negative"],
)
def test_bulk_stream_slices(stream_config_date_format, stream_api, lookback, expect_error):
stream: BulkIncrementalSalesforceStream = generate_stream("FakeBulkStream", stream_config_date_format, stream_api)
stream_slices = list(stream.stream_slices(sync_mode=SyncMode.full_refresh))
expected_slices = []
today = pendulum.today(tz="UTC")
start_date = pendulum.parse(stream.start_date, tz="UTC")
while start_date < today:
expected_slices.append(
{
"start_date": start_date.isoformat(timespec="milliseconds"),
"end_date": min(today, start_date.add(days=stream.STREAM_SLICE_STEP)).isoformat(timespec="milliseconds"),
}
)
start_date = start_date.add(days=stream.STREAM_SLICE_STEP)
assert expected_slices == stream_slices
with patch("source_salesforce.streams.LOOKBACK_SECONDS", lookback):
if expect_error:
with pytest.raises(AssertionError):
list(stream.stream_slices(sync_mode=SyncMode.full_refresh))
else:
stream_slices = list(stream.stream_slices(sync_mode=SyncMode.full_refresh))

expected_slices = []
today = pendulum.today(tz="UTC")
start_date = pendulum.parse(stream.start_date, tz="UTC") - timedelta(seconds=lookback)
while start_date < today:
expected_slices.append(
{
"start_date": start_date.isoformat(timespec="milliseconds"),
"end_date": min(today, start_date.add(days=stream.STREAM_SLICE_STEP)).isoformat(timespec="milliseconds"),
}
)
start_date = start_date.add(days=stream.STREAM_SLICE_STEP)
assert expected_slices == stream_slices


@freezegun.freeze_time("2023-04-01")
def test_bulk_stream_request_params_states(stream_config_date_format, stream_api, bulk_catalog, requests_mock):
"""Check that request params ignore records cursor and use start date from slice ONLY"""
stream_config_date_format.update({"start_date": "2023-01-01"})
stream: BulkIncrementalSalesforceStream = generate_stream("Account", stream_config_date_format, stream_api)

source = SourceSalesforce(_ANY_CATALOG, _ANY_CONFIG)
source = SourceSalesforce(_ANY_CATALOG, _ANY_CONFIG, _ANY_STATE)
source.streams = Mock()
source.streams.return_value = [stream]

@@ -938,7 +929,8 @@ def test_bulk_stream_request_params_states(stream_config_date_format, stream_api
logger = logging.getLogger("airbyte")
state = {"Account": {"LastModifiedDate": "2023-01-01T10:10:10.000Z"}}
bulk_catalog.streams.pop(1)
result = [i for i in source.read(logger=logger, config=stream_config_date_format, catalog=bulk_catalog, state=state)]
with patch("source_salesforce.streams.LOOKBACK_SECONDS", 0):
result = [i for i in source.read(logger=logger, config=stream_config_date_format, catalog=bulk_catalog, state=state)]

actual_state_values = [item.state.data.get("Account").get(stream.cursor_field) for item in result if item.type == Type.STATE]
# assert request params
