[ISSUE #6829] update salesforce to support partitioned state #36942

Merged
Changes from 6 commits
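This PR moves the Salesforce streams onto the concurrent CDK: incremental streams are wrapped in a StreamFacade and their slices and state are driven by a ConcurrentCursor, so state is tracked per synced date range ("partitioned") instead of as a single sequential cursor value. The sketch below only illustrates the difference in state shape; the cursor field name and the exact concurrent-state layout are assumptions for illustration, not taken from this diff.

# Illustration only: the field name and exact layout are assumptions.
legacy_sequential_state = {"SystemModstamp": "2023-05-01T00:00:00.000Z"}

partitioned_state = {
    "state_type": "date-range",
    "slices": [
        {"start": "2021-01-01T00:00:00.000Z", "end": "2023-05-01T00:00:00.000Z"},
    ],
}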
@@ -0,0 +1,3 @@
[run]
omit =
source_salesforce/run.py
@@ -12,6 +12,7 @@
import pytest
import requests
from airbyte_cdk.models import SyncMode
from airbyte_protocol.models import ConfiguredAirbyteCatalog
from source_salesforce.api import Salesforce
from source_salesforce.source import SourceSalesforce

@@ -20,6 +21,10 @@
NOTE_CONTENT = "It's the note for integration test"
UPDATED_NOTE_CONTENT = "It's the updated note for integration test"

_ANY_CATALOG = ConfiguredAirbyteCatalog.parse_obj({"streams": []})
_ANY_CONFIG = {}
_ANY_STATE = {}


@pytest.fixture(scope="module")
def input_sandbox_config():
@@ -41,7 +46,7 @@ def stream_name():

@pytest.fixture(scope="module")
def stream(input_sandbox_config, stream_name, sf):
return SourceSalesforce.generate_streams(input_sandbox_config, {stream_name: None}, sf)[0]
return SourceSalesforce(_ANY_CATALOG, _ANY_CONFIG, _ANY_STATE).generate_streams(input_sandbox_config, {stream_name: None}, sf)[0]._legacy_stream


def _encode_content(text):
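The test change above reflects the new API: generate_streams is now an instance method (see the generate_streams hunk below), and the streams it returns are concurrent StreamFacade wrappers, so the test reaches the underlying stream through _legacy_stream. A minimal sketch of that call pattern, reusing the test's fixtures and placeholder constants:

# Sketch only, reusing _ANY_CATALOG/_ANY_CONFIG/_ANY_STATE and the fixtures above.
source = SourceSalesforce(_ANY_CATALOG, _ANY_CONFIG, _ANY_STATE)
facade = source.generate_streams(input_sandbox_config, {stream_name: None}, sf)[0]
legacy_stream = facade._legacy_stream  # the wrapped REST/bulk Salesforce stream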
@@ -10,7 +10,7 @@ data:
connectorSubtype: api
connectorType: source
definitionId: b117307c-14b6-41aa-9422-947e34922962
dockerImageTag: 2.4.4
dockerImageTag: 2.4.5
dockerRepository: airbyte/source-salesforce
documentationUrl: https://docs.airbyte.com/integrations/sources/salesforce
githubIssueLabel: source-salesforce

@@ -3,7 +3,7 @@ requires = [ "poetry-core>=1.0.0",]
build-backend = "poetry.core.masonry.api"

[tool.poetry]
version = "2.4.4"
version = "2.4.5"
name = "source-salesforce"
description = "Source implementation for Salesforce."
authors = [ "Airbyte <contact@airbyte.io>",]
@@ -29,7 +29,7 @@ def handle_http_error(
if error.response.status_code in [codes.FORBIDDEN, codes.BAD_REQUEST]:
error_data = error.response.json()[0]
error_code = error_data.get("errorCode", "")
if error_code != "REQUEST_LIMIT_EXCEEDED" or error_code == "INVALID_TYPE_FOR_OPERATION":
if error_code != "REQUEST_LIMIT_EXCEEDED":
return False, f"Cannot receive data for stream '{stream.name}', error message: '{error_data.get('message')}'"
return True, None
raise error
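The dropped 'or error_code == "INVALID_TYPE_FOR_OPERATION"' clause was redundant: any error code equal to "INVALID_TYPE_FOR_OPERATION" already differs from "REQUEST_LIMIT_EXCEEDED", so the first operand of the or was always true in that case. A quick equivalence check:

# Quick check that the old and new conditions are equivalent.
for code in ("REQUEST_LIMIT_EXCEEDED", "INVALID_TYPE_FOR_OPERATION", "ANY_OTHER_CODE"):
    old = code != "REQUEST_LIMIT_EXCEEDED" or code == "INVALID_TYPE_FOR_OPERATION"
    new = code != "REQUEST_LIMIT_EXCEEDED"
    assert old == new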
@@ -3,9 +3,10 @@
#

import logging
from datetime import datetime
from datetime import datetime, timedelta, timezone
from typing import Any, Iterator, List, Mapping, MutableMapping, Optional, Tuple, Union

import isodate
import pendulum
import requests
from airbyte_cdk import AirbyteLogger
@@ -29,6 +30,7 @@

from .api import PARENT_SALESFORCE_OBJECTS, UNSUPPORTED_BULK_API_SALESFORCE_OBJECTS, UNSUPPORTED_FILTERING_STREAMS, Salesforce
from .streams import (
LOOKBACK_SECONDS,
BulkIncrementalSalesforceStream,
BulkSalesforceStream,
BulkSalesforceSubStream,
@@ -172,9 +174,8 @@ def prepare_stream(cls, stream_name: str, json_schema, sobject_options, sf_objec

return stream_class, stream_kwargs

@classmethod
def generate_streams(
cls,
self,
config: Mapping[str, Any],
stream_objects: Mapping[str, Any],
sf_object: Salesforce,
@@ -184,69 +185,83 @@ def generate_streams(
schemas = sf_object.generate_schemas(stream_objects)
default_args = [sf_object, authenticator, config]
streams = []
state_manager = ConnectorStateManager(stream_instance_map={s.name: s for s in streams}, state=self.state)
for stream_name, sobject_options in stream_objects.items():
json_schema = schemas.get(stream_name, {})

stream_class, kwargs = cls.prepare_stream(stream_name, json_schema, sobject_options, *default_args)
stream_class, kwargs = self.prepare_stream(stream_name, json_schema, sobject_options, *default_args)

parent_name = PARENT_SALESFORCE_OBJECTS.get(stream_name, {}).get("parent_name")
if parent_name:
# get minimal schema required for getting proper class name full_refresh/incremental, rest/bulk
parent_schema = PARENT_SALESFORCE_OBJECTS.get(stream_name, {}).get("schema_minimal")
parent_class, parent_kwargs = cls.prepare_stream(parent_name, parent_schema, sobject_options, *default_args)
parent_class, parent_kwargs = self.prepare_stream(parent_name, parent_schema, sobject_options, *default_args)
kwargs["parent"] = parent_class(**parent_kwargs)

stream = stream_class(**kwargs)

api_type = cls._get_api_type(stream_name, json_schema, config.get("force_use_bulk_api", False))
api_type = self._get_api_type(stream_name, json_schema, config.get("force_use_bulk_api", False))
if api_type == "rest" and not stream.primary_key and stream.too_many_properties:
logger.warning(
f"Can not instantiate stream {stream_name}. It is not supported by the BULK API and can not be "
"implemented via REST because the number of its properties exceeds the limit and it lacks a primary key."
)
continue
streams.append(stream)

streams.append(self._wrap_for_concurrency(config, stream, state_manager))
streams.append(self._wrap_for_concurrency(config, Describe(sf_api=sf_object, catalog=self.catalog), state_manager))
return streams

def _wrap_for_concurrency(self, config, stream, state_manager):
stream_slicer_cursor = self._create_stream_slicer_cursor(config, state_manager, stream)
if hasattr(stream, "set_cursor"):
stream.set_cursor(stream_slicer_cursor)
elif hasattr(stream, "parent") and hasattr(stream.parent, "set_cursor"):
stream.parent.set_cursor(stream_slicer_cursor)

if self._get_sync_mode_from_catalog(stream) == SyncMode.full_refresh:
cursor = FinalStateCursor(
stream_name=stream.name, stream_namespace=stream.namespace, message_repository=self.message_repository
)
state = None
else:
cursor = stream_slicer_cursor
state = cursor.state
return StreamFacade.create_from_stream(stream, self, logger, state, cursor)

def streams(self, config: Mapping[str, Any]) -> List[Stream]:
if not config.get("start_date"):
config["start_date"] = (datetime.now() - relativedelta(years=self.START_DATE_OFFSET_IN_YEARS)).strftime(self.DATETIME_FORMAT)
sf = self._get_sf_object(config)
stream_objects = sf.get_validated_streams(config=config, catalog=self.catalog)
streams = self.generate_streams(config, stream_objects, sf)
streams.append(Describe(sf_api=sf, catalog=self.catalog))
state_manager = ConnectorStateManager(stream_instance_map={s.name: s for s in streams}, state=self.state)

configured_streams = []

for stream in streams:
sync_mode = self._get_sync_mode_from_catalog(stream)
if sync_mode == SyncMode.full_refresh:
cursor = FinalStateCursor(
stream_name=stream.name, stream_namespace=stream.namespace, message_repository=self.message_repository
)
state = None
else:
cursor_field_key = stream.cursor_field or ""
if not isinstance(cursor_field_key, str):
raise AssertionError(f"A string cursor field key is required, but got {cursor_field_key}.")
cursor_field = CursorField(cursor_field_key)
legacy_state = state_manager.get_stream_state(stream.name, stream.namespace)
cursor = ConcurrentCursor(
stream.name,
stream.namespace,
legacy_state,
self.message_repository,
state_manager,
stream.state_converter,
cursor_field,
self._get_slice_boundary_fields(stream, state_manager),
config["start_date"],
)
state = cursor.state
return streams

configured_streams.append(StreamFacade.create_from_stream(stream, self, logger, state, cursor))
return configured_streams
def _create_stream_slicer_cursor(
self, config: Mapping[str, Any], state_manager: ConnectorStateManager, stream: Stream
) -> ConcurrentCursor:
"""
We have moved the generation of stream slices to the concurrent CDK cursor
"""
cursor_field_key = stream.cursor_field or ""
if not isinstance(cursor_field_key, str):
raise AssertionError(f"A string cursor field key is required, but got {cursor_field_key}.")
cursor_field = CursorField(cursor_field_key)
legacy_state = state_manager.get_stream_state(stream.name, stream.namespace)
return ConcurrentCursor(
stream.name,
stream.namespace,
legacy_state,
self.message_repository,
state_manager,
stream.state_converter,
cursor_field,
self._get_slice_boundary_fields(stream, state_manager),
datetime.fromtimestamp(pendulum.parse(config["start_date"]).timestamp(), timezone.utc),
stream.state_converter.get_end_provider(),
timedelta(seconds=LOOKBACK_SECONDS),
isodate.parse_duration(config["stream_slice_step"]) if "stream_slice_step" in config else timedelta(days=30),
)

def _get_slice_boundary_fields(self, stream: Stream, state_manager: ConnectorStateManager) -> Optional[Tuple[str, str]]:
return ("start_date", "end_date")
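For reference, the start-date and slice-step arguments passed to ConcurrentCursor above come from the connector config: the start date is parsed with pendulum and converted to an aware datetime, and stream_slice_step is an ISO-8601 duration that falls back to 30 days. A small, self-contained illustration with made-up config values:

# Illustration with made-up config values.
import isodate
import pendulum
from datetime import datetime, timedelta, timezone

config = {"start_date": "2021-01-01T00:00:00Z", "stream_slice_step": "P7D"}
start = datetime.fromtimestamp(pendulum.parse(config["start_date"]).timestamp(), timezone.utc)
step = isodate.parse_duration(config["stream_slice_step"]) if "stream_slice_step" in config else timedelta(days=30)
# start -> 2021-01-01 00:00:00+00:00, step -> timedelta(days=7)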
@@ -11,14 +11,16 @@
import uuid
from abc import ABC
from contextlib import closing
from typing import Any, Callable, Dict, Iterable, List, Mapping, MutableMapping, Optional, Tuple, Type, Union
from datetime import timedelta
from typing import Any, Callable, Iterable, List, Mapping, MutableMapping, Optional, Tuple, Type, Union

import backoff
import pandas as pd
import pendulum
import requests # type: ignore[import]
from airbyte_cdk.models import ConfiguredAirbyteCatalog, FailureType, SyncMode
from airbyte_cdk.sources.streams.availability_strategy import AvailabilityStrategy
from airbyte_cdk.sources.streams.concurrent.cursor import Cursor
from airbyte_cdk.sources.streams.concurrent.state_converters.datetime_stream_state_converter import IsoMillisConcurrentStreamStateConverter
from airbyte_cdk.sources.streams.core import Stream, StreamData
from airbyte_cdk.sources.streams.http import HttpStream, HttpSubStream
@@ -44,7 +46,7 @@


class SalesforceStream(HttpStream, ABC):
state_converter = IsoMillisConcurrentStreamStateConverter()
state_converter = IsoMillisConcurrentStreamStateConverter(is_sequential_state=False)
page_size = 2000
transformer = TypeTransformer(TransformConfig.DefaultSchemaNormalization)
encoding = DEFAULT_ENCODING
@@ -142,7 +144,7 @@ def __init__(self, properties: Mapping[str, Any]):


class RestSalesforceStream(SalesforceStream):
state_converter = IsoMillisConcurrentStreamStateConverter()
state_converter = IsoMillisConcurrentStreamStateConverter(is_sequential_state=False)

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -320,7 +322,7 @@ def _fetch_next_page_for_chunk(


class BatchedSubStream(HttpSubStream):
state_converter = IsoMillisConcurrentStreamStateConverter()
state_converter = IsoMillisConcurrentStreamStateConverter(is_sequential_state=False)
SLICE_BATCH_SIZE = 200

def stream_slices(
@@ -705,7 +707,10 @@ def get_standard_instance(self) -> SalesforceStream:
stream_kwargs.update({"replication_key": self.replication_key, "start_date": self.start_date})
new_cls = IncrementalRestSalesforceStream

return new_cls(**stream_kwargs)
standard_instance = new_cls(**stream_kwargs)
if hasattr(standard_instance, "set_cursor"):
standard_instance.set_cursor(self._stream_slicer_cursor)
return standard_instance


class BulkSalesforceSubStream(BatchedSubStream, BulkSalesforceStream):
@@ -732,24 +737,22 @@ def __init__(self, replication_key: str, stream_slice_step: str = "P30D", **kwar
super().__init__(**kwargs)
self.replication_key = replication_key
self._stream_slice_step = stream_slice_step
self._stream_slicer_cursor = None

def set_cursor(self, cursor: Cursor) -> None:
self._stream_slicer_cursor = cursor

def stream_slices(
self, *, sync_mode: SyncMode, cursor_field: List[str] = None, stream_state: Mapping[str, Any] = None
) -> Iterable[Optional[Mapping[str, Any]]]:
now = pendulum.now(tz="UTC")
assert LOOKBACK_SECONDS is not None and LOOKBACK_SECONDS >= 0

initial_date = self.get_start_date_from_state(stream_state) - pendulum.Duration(seconds=LOOKBACK_SECONDS)
slice_start = initial_date
while slice_start < now:
slice_end = slice_start + self.stream_slice_step
self._slice = {
if not self._stream_slicer_cursor:
raise ValueError("Cursor should be set at this point")

for slice_start, slice_end in self._stream_slicer_cursor.generate_slices():
yield {
"start_date": slice_start.isoformat(timespec="milliseconds"),
"end_date": min(slice_end, now).isoformat(timespec="milliseconds"),
"end_date": slice_end.isoformat(timespec="milliseconds"),
}
yield self._slice

slice_start += self.stream_slice_step

@property
def stream_slice_step(self) -> pendulum.Duration:
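stream_slices now delegates window generation to the cursor provided via set_cursor. As a rough, self-contained sketch of the date-range windows involved (this mirrors the removed loop, including the lookback; it is not the concurrent CDK's ConcurrentCursor implementation):

# Rough sketch of the windowing behaviour; not the CDK implementation.
from datetime import datetime, timedelta, timezone

def example_slices(start: datetime, step: timedelta, lookback: timedelta):
    slice_start = start - lookback
    now = datetime.now(timezone.utc)
    while slice_start < now:
        slice_end = min(slice_start + step, now)
        yield {
            "start_date": slice_start.isoformat(timespec="milliseconds"),
            "end_date": slice_end.isoformat(timespec="milliseconds"),
        }
        slice_start += step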
@@ -829,7 +832,7 @@ def request_params(


class Describe(Stream):
state_converter = IsoMillisConcurrentStreamStateConverter()
state_converter = IsoMillisConcurrentStreamStateConverter(is_sequential_state=False)
"""
Stream of sObjects' (Salesforce Objects) describe:
https://developer.salesforce.com/docs/atlas.en-us.api_rest.meta/api_rest/resources_sobject_describe.htm