🎉 Source Salesforce new ContentDocumentLink stream (#33342)
midavadim committed Dec 15, 2023
1 parent 2b6bcfc commit 0859afc
Showing 6 changed files with 199 additions and 49 deletions.
@@ -10,7 +10,7 @@ data:
connectorSubtype: api
connectorType: source
definitionId: b117307c-14b6-41aa-9422-947e34922962
dockerImageTag: 2.2.0
dockerImageTag: 2.2.1
dockerRepository: airbyte/source-salesforce
documentationUrl: https://docs.airbyte.com/integrations/sources/salesforce
githubIssueLabel: source-salesforce
@@ -51,7 +51,6 @@
"AppTabMember",
"CollaborationGroupRecord",
"ColorDefinition",
"ContentDocumentLink",
"ContentFolderItem",
"ContentFolderMember",
"DataStatistics",
@@ -129,6 +128,19 @@
"UserRecordAccess",
]

PARENT_SALESFORCE_OBJECTS = {
# parent_name - name of the parent stream
# field - the field in each parent record that is needed to build the stream slice
# schema_minimal - the minimal schema needed to pick the proper stream class (full_refresh/incremental, rest/bulk) for the parent stream
"ContentDocumentLink": {
"parent_name": "ContentDocument",
"field": "Id",
"schema_minimal": {
"properties": {"Id": {"type": ["string", "null"]}, "SystemModstamp": {"type": ["string", "null"], "format": "date-time"}}
},
}
}
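For illustration only (not part of this diff), the registry above is read like this; the values come straight from the entry it defines:

# Illustrative lookup of the ContentDocumentLink parent metadata defined above.
link_config = PARENT_SALESFORCE_OBJECTS["ContentDocumentLink"]
assert link_config["parent_name"] == "ContentDocument"  # parent stream whose records drive the slices
assert link_config["field"] == "Id"                     # parent field injected into each stream slice
assert "SystemModstamp" in link_config["schema_minimal"]["properties"]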

# The following objects are not supported by the Bulk API. Listed objects are version specific.
UNSUPPORTED_BULK_API_SALESFORCE_OBJECTS = [
"AcceptedEventRelation",
@@ -184,6 +196,7 @@
UNSUPPORTED_FILTERING_STREAMS = [
"ApiEvent",
"BulkApiResultEventStore",
"ContentDocumentLink",
"EmbeddedServiceDetail",
"EmbeddedServiceLabel",
"FormulaFunction",
@@ -23,8 +23,16 @@
from dateutil.relativedelta import relativedelta
from requests import codes, exceptions # type: ignore[import]

from .api import UNSUPPORTED_BULK_API_SALESFORCE_OBJECTS, UNSUPPORTED_FILTERING_STREAMS, Salesforce
from .streams import BulkIncrementalSalesforceStream, BulkSalesforceStream, Describe, IncrementalRestSalesforceStream, RestSalesforceStream
from .api import PARENT_SALESFORCE_OBJECTS, UNSUPPORTED_BULK_API_SALESFORCE_OBJECTS, UNSUPPORTED_FILTERING_STREAMS, Salesforce
from .streams import (
BulkIncrementalSalesforceStream,
BulkSalesforceStream,
BulkSalesforceSubStream,
Describe,
IncrementalRestSalesforceStream,
RestSalesforceStream,
RestSalesforceSubStream,
)

_DEFAULT_CONCURRENCY = 10
_MAX_CONCURRENCY = 10
@@ -79,8 +87,10 @@ def check_connection(self, logger: AirbyteLogger, config: Mapping[str, Any]) ->
return True, None

@classmethod
def _get_api_type(cls, stream_name: str, properties: Mapping[str, Any], force_use_bulk_api: bool) -> str:
def _get_api_type(cls, stream_name: str, json_schema: Mapping[str, Any], force_use_bulk_api: bool) -> str:
"""Get proper API type: rest or bulk"""
# Salesforce BULK API currently does not support loading fields with data type base64 and compound data
properties = json_schema.get("properties", {})
properties_not_supported_by_bulk = {
key: value for key, value in properties.items() if value.get("format") == "base64" or "object" in value["type"]
}
@@ -97,50 +107,87 @@ def _get_api_type(cls, stream_name: str, properties: Mapping[str, Any], force_us
return "rest"
return "bulk"

@classmethod
def _get_stream_type(cls, stream_name: str, api_type: str):
"""Get proper stream class: full_refresh, incremental or substream
SubStreams (like ContentDocumentLink) do not support incremental sync because of query restrictions; see:
https://developer.salesforce.com/docs/atlas.en-us.object_reference.meta/object_reference/sforce_api_objects_contentdocumentlink.htm
"""
parent_name = PARENT_SALESFORCE_OBJECTS.get(stream_name, {}).get("parent_name")
if api_type == "rest":
full_refresh = RestSalesforceSubStream if parent_name else RestSalesforceStream
incremental = IncrementalRestSalesforceStream
elif api_type == "bulk":
full_refresh = BulkSalesforceSubStream if parent_name else BulkSalesforceStream
incremental = BulkIncrementalSalesforceStream
else:
raise Exception(f"Stream {stream_name} cannot be processed by REST or BULK API.")
return full_refresh, incremental
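For illustration only (not part of this diff, and assuming the enclosing class is SourceSalesforce as in the connector's source.py), the selection above resolves like this:

# Illustrative: substreams get the SubStream variant, regular streams the plain one.
assert SourceSalesforce._get_stream_type("ContentDocumentLink", "rest") == (
    RestSalesforceSubStream,
    IncrementalRestSalesforceStream,
)
assert SourceSalesforce._get_stream_type("Account", "bulk") == (
    BulkSalesforceStream,
    BulkIncrementalSalesforceStream,
)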

@classmethod
def prepare_stream(cls, stream_name: str, json_schema, sobject_options, sf_object, authenticator, config):
"""Choose proper stream class: syncMode(full_refresh/incremental), API type(Rest/Bulk), SubStream"""
pk, replication_key = sf_object.get_pk_and_replication_key(json_schema)
stream_kwargs = {
"stream_name": stream_name,
"schema": json_schema,
"pk": pk,
"sobject_options": sobject_options,
"sf_api": sf_object,
"authenticator": authenticator,
"start_date": config.get("start_date"),
}

api_type = cls._get_api_type(stream_name, json_schema, config.get("force_use_bulk_api", False))
full_refresh, incremental = cls._get_stream_type(stream_name, api_type)
if replication_key and stream_name not in UNSUPPORTED_FILTERING_STREAMS:
stream_class = incremental
stream_kwargs["replication_key"] = replication_key
else:
stream_class = full_refresh

return stream_class, stream_kwargs

@classmethod
def generate_streams(
cls,
config: Mapping[str, Any],
stream_objects: Mapping[str, Any],
sf_object: Salesforce,
) -> List[Stream]:
""" "Generates a list of stream by their names. It can be used for different tests too"""
"""Generates a list of stream by their names. It can be used for different tests too"""
authenticator = TokenAuthenticator(sf_object.access_token)
stream_properties = sf_object.generate_schemas(stream_objects)
schemas = sf_object.generate_schemas(stream_objects)
default_args = [sf_object, authenticator, config]
streams = []
for stream_name, sobject_options in stream_objects.items():
streams_kwargs = {"sobject_options": sobject_options}
selected_properties = stream_properties.get(stream_name, {}).get("properties", {})

api_type = cls._get_api_type(stream_name, selected_properties, config.get("force_use_bulk_api", False))
if api_type == "rest":
full_refresh, incremental = RestSalesforceStream, IncrementalRestSalesforceStream
elif api_type == "bulk":
full_refresh, incremental = BulkSalesforceStream, BulkIncrementalSalesforceStream
else:
raise Exception(f"Stream {stream_name} cannot be processed by REST or BULK API.")

json_schema = stream_properties.get(stream_name, {})
pk, replication_key = sf_object.get_pk_and_replication_key(json_schema)
streams_kwargs.update(dict(sf_api=sf_object, pk=pk, stream_name=stream_name, schema=json_schema, authenticator=authenticator))
if replication_key and stream_name not in UNSUPPORTED_FILTERING_STREAMS:
start_date = config.get(
"start_date", (datetime.now() - relativedelta(years=cls.START_DATE_OFFSET_IN_YEARS)).strftime(cls.DATETIME_FORMAT)
)
stream = incremental(**streams_kwargs, replication_key=replication_key, start_date=start_date)
else:
stream = full_refresh(**streams_kwargs)
json_schema = schemas.get(stream_name, {})

stream_class, kwargs = cls.prepare_stream(stream_name, json_schema, sobject_options, *default_args)

parent_name = PARENT_SALESFORCE_OBJECTS.get(stream_name, {}).get("parent_name")
if parent_name:
# get the minimal schema needed to pick the proper stream class (full_refresh/incremental, rest/bulk)
parent_schema = PARENT_SALESFORCE_OBJECTS.get(stream_name, {}).get("schema_minimal")
parent_class, parent_kwargs = cls.prepare_stream(parent_name, parent_schema, sobject_options, *default_args)
kwargs["parent"] = parent_class(**parent_kwargs)

stream = stream_class(**kwargs)

api_type = cls._get_api_type(stream_name, json_schema, config.get("force_use_bulk_api", False))
if api_type == "rest" and not stream.primary_key and stream.too_many_properties:
logger.warning(
f"Can not instantiate stream {stream_name}. "
f"It is not supported by the BULK API and can not be implemented via REST because the number of its properties "
f"exceeds the limit and it lacks a primary key."
f"Can not instantiate stream {stream_name}. It is not supported by the BULK API and can not be "
"implemented via REST because the number of its properties exceeds the limit and it lacks a primary key."
)
continue
streams.append(stream)
return streams

def streams(self, config: Mapping[str, Any]) -> List[Stream]:
if not config.get("start_date"):
config["start_date"] = (datetime.now() - relativedelta(years=self.START_DATE_OFFSET_IN_YEARS)).strftime(self.DATETIME_FORMAT)
sf = self._get_sf_object(config)
stream_objects = sf.get_validated_streams(config=config, catalog=self.catalog)
streams = self.generate_streams(config, stream_objects, sf)
@@ -19,15 +19,15 @@
from airbyte_cdk.models import ConfiguredAirbyteCatalog, FailureType, SyncMode
from airbyte_cdk.sources.streams.availability_strategy import AvailabilityStrategy
from airbyte_cdk.sources.streams.core import Stream, StreamData
from airbyte_cdk.sources.streams.http import HttpStream
from airbyte_cdk.sources.streams.http import HttpStream, HttpSubStream
from airbyte_cdk.sources.utils.transform import TransformConfig, TypeTransformer
from airbyte_cdk.utils import AirbyteTracedException
from numpy import nan
from pendulum import DateTime # type: ignore[attr-defined]
from requests import codes, exceptions
from requests.models import PreparedRequest

from .api import UNSUPPORTED_FILTERING_STREAMS, Salesforce
from .api import PARENT_SALESFORCE_OBJECTS, UNSUPPORTED_FILTERING_STREAMS, Salesforce
from .availability_strategy import SalesforceAvailabilityStrategy
from .exceptions import SalesforceException, TmpFileIOError
from .rate_limiting import default_backoff_handler
@@ -45,14 +45,29 @@ class SalesforceStream(HttpStream, ABC):
encoding = DEFAULT_ENCODING

def __init__(
self, sf_api: Salesforce, pk: str, stream_name: str, sobject_options: Mapping[str, Any] = None, schema: dict = None, **kwargs
self,
sf_api: Salesforce,
pk: str,
stream_name: str,
sobject_options: Mapping[str, Any] = None,
schema: dict = None,
start_date=None,
**kwargs,
):
super().__init__(**kwargs)
self.sf_api = sf_api
self.pk = pk
self.stream_name = stream_name
self.schema: Mapping[str, Any] = schema # type: ignore[assignment]
self.sobject_options = sobject_options
self.start_date = self.format_start_date(start_date)

@staticmethod
def format_start_date(start_date: Optional[str]) -> Optional[str]:
"""Transform the format `2021-07-25` into the format `2021-07-25T00:00:00Z`"""
if start_date:
return pendulum.parse(start_date).strftime("%Y-%m-%dT%H:%M:%SZ") # type: ignore[attr-defined,no-any-return]
return None
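For illustration only (not part of this diff), the helper normalizes a date-only start date into a Zulu timestamp:

# Illustrative usage of the static helper above.
assert SalesforceStream.format_start_date("2021-07-25") == "2021-07-25T00:00:00Z"
assert SalesforceStream.format_start_date(None) is None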

@property
def max_properties_length(self) -> int:
@@ -141,14 +156,18 @@ def request_params(
Salesforce SOQL Query: https://developer.salesforce.com/docs/atlas.en-us.232.0.api_rest.meta/api_rest/dome_queryall.htm
"""
if next_page_token:
"""
If `next_page_token` is set, subsequent requests use `nextRecordsUrl`, and do not include any parameters.
"""
# If `next_page_token` is set, subsequent requests use `nextRecordsUrl`, and do not include any parameters.
return {}

property_chunk = property_chunk or {}
query = f"SELECT {','.join(property_chunk.keys())} FROM {self.name} "

if self.name in PARENT_SALESFORCE_OBJECTS:
# add where clause: " WHERE ContentDocumentId IN ('06905000000NMXXXXX', ...)"
parent_field = PARENT_SALESFORCE_OBJECTS[self.name]["field"]
parent_ids = [f"'{parent_record[parent_field]}'" for parent_record in stream_slice["parents"]]
query += f" WHERE ContentDocumentId IN ({','.join(parent_ids)})"

if self.primary_key and self.name not in UNSUPPORTED_FILTERING_STREAMS:
query += f"ORDER BY {self.primary_key} ASC"
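For illustration only (not part of this diff; the record ids are made up), the PARENT_SALESFORCE_OBJECTS branch above produces a clause like this for a two-record slice:

# Illustrative: the WHERE clause built from a slice of two hypothetical parent records.
stream_slice = {"parents": [{"Id": "069XX0000000001"}, {"Id": "069XX0000000002"}]}
parent_field = PARENT_SALESFORCE_OBJECTS["ContentDocumentLink"]["field"]  # "Id"
parent_ids = [f"'{record[parent_field]}'" for record in stream_slice["parents"]]
assert f" WHERE ContentDocumentId IN ({','.join(parent_ids)})" == (
    " WHERE ContentDocumentId IN ('069XX0000000001','069XX0000000002')"
)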

@@ -282,6 +301,30 @@ def _fetch_next_page_for_chunk(
return request, response


class BatchedSubStream(HttpSubStream):
SLICE_BATCH_SIZE = 200

def stream_slices(
self, sync_mode: SyncMode, cursor_field: Optional[List[str]] = None, stream_state: Optional[Mapping[str, Any]] = None
) -> Iterable[Optional[Mapping[str, Any]]]:
"""Instead of yielding one parent record at a time, make stream slice contain a batch of parent records.
It allows to get <SLICE_BATCH_SIZE> records by one requests (instead of only one).
"""
batched_slice = []
for stream_slice in super().stream_slices(sync_mode, cursor_field, stream_state):
if len(batched_slice) == self.SLICE_BATCH_SIZE:
yield {"parents": batched_slice}
batched_slice = []
batched_slice.append(stream_slice["parent"])
if batched_slice:
yield {"parents": batched_slice}
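For illustration only (not part of this diff), a standalone sketch of the batching above: 450 parent records produce three slices of 200, 200 and 50.

# Illustrative: mirrors the batching performed by stream_slices above.
def _batch_parents(parent_records, batch_size=200):
    batch = []
    for record in parent_records:
        if len(batch) == batch_size:
            yield {"parents": batch}
            batch = []
        batch.append(record)
    if batch:
        yield {"parents": batch}

slices = list(_batch_parents([{"Id": str(i)} for i in range(450)]))
assert [len(s["parents"]) for s in slices] == [200, 200, 50]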


class RestSalesforceSubStream(BatchedSubStream, RestSalesforceStream):
pass


class BulkSalesforceStream(SalesforceStream):
DEFAULT_WAIT_TIMEOUT_SECONDS = 86400 # 24-hour bulk job running time
MAX_CHECK_INTERVAL_SECONDS = 2.0
@@ -542,6 +585,12 @@ def request_params(
if next_page_token:
query += next_page_token["next_token"]

if self.name in PARENT_SALESFORCE_OBJECTS:
# add where clause: " WHERE ContentDocumentId IN ('06905000000NMXXXXX', '06905000000Mxp7XXX', ...)"
parent_field = PARENT_SALESFORCE_OBJECTS[self.name]["field"]
parent_ids = [f"'{parent_record[parent_field]}'" for parent_record in stream_slice["parents"]]
query += f" WHERE ContentDocumentId IN ({','.join(parent_ids)})"

return {"q": query}

def read_records(
@@ -605,6 +654,10 @@ def get_standard_instance(self) -> SalesforceStream:
return new_cls(**stream_kwargs)


class BulkSalesforceSubStream(BatchedSubStream, BulkSalesforceStream):
pass
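For illustration only (not part of this diff), the order of the base classes matters here: BatchedSubStream comes first, so its batched stream_slices is the implementation these subclasses use.

# Illustrative: BatchedSubStream precedes the API-specific base in the MRO.
mro = [cls.__name__ for cls in BulkSalesforceSubStream.__mro__]
assert mro.index("BatchedSubStream") < mro.index("BulkSalesforceStream")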


@BulkSalesforceStream.transformer.registerCustomTransform
def transform_empty_string_to_none(instance: Any, schema: Any):
"""
@@ -622,17 +675,9 @@ class IncrementalRestSalesforceStream(RestSalesforceStream, ABC):
STREAM_SLICE_STEP = 30
_slice = None

def __init__(self, replication_key: str, start_date: Optional[str], **kwargs):
def __init__(self, replication_key: str, **kwargs):
super().__init__(**kwargs)
self.replication_key = replication_key
self.start_date = self.format_start_date(start_date)

@staticmethod
def format_start_date(start_date: Optional[str]) -> Optional[str]:
"""Transform the format `2021-07-25` into the format `2021-07-25T00:00:00Z`"""
if start_date:
return pendulum.parse(start_date).strftime("%Y-%m-%dT%H:%M:%SZ") # type: ignore[attr-defined,no-any-return]
return None

def stream_slices(
self, *, sync_mode: SyncMode, cursor_field: List[str] = None, stream_state: Mapping[str, Any] = None
