Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Source Salesforce: handle too many properties #22597

Merged
merged 7 commits into from
Feb 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1565,7 +1565,7 @@
- name: Salesforce
sourceDefinitionId: b117307c-14b6-41aa-9422-947e34922962
dockerRepository: airbyte/source-salesforce
dockerImageTag: 2.0.0
dockerImageTag: 2.0.1
documentationUrl: https://docs.airbyte.com/integrations/sources/salesforce
icon: salesforce.svg
sourceType: api
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13106,7 +13106,7 @@
supportsNormalization: false
supportsDBT: false
supported_destination_sync_modes: []
- dockerImage: "airbyte/source-salesforce:2.0.0"
- dockerImage: "airbyte/source-salesforce:2.0.1"
spec:
documentationUrl: "https://docs.airbyte.com/integrations/sources/salesforce"
connectionSpecification:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,5 +13,5 @@ RUN pip install .

ENTRYPOINT ["python", "/airbyte/integration_code/main.py"]

LABEL io.airbyte.version=2.0.0
LABEL io.airbyte.version=2.0.1
LABEL io.airbyte.name=airbyte/source-salesforce
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from requests import codes, exceptions # type: ignore[import]

from .api import UNSUPPORTED_BULK_API_SALESFORCE_OBJECTS, UNSUPPORTED_FILTERING_STREAMS, Salesforce
from .streams import BulkIncrementalSalesforceStream, BulkSalesforceStream, Describe, IncrementalSalesforceStream, SalesforceStream
from .streams import BulkIncrementalSalesforceStream, BulkSalesforceStream, Describe, IncrementalRestSalesforceStream, RestSalesforceStream


class AirbyteStopSync(AirbyteTracedException):
Expand Down Expand Up @@ -59,17 +59,10 @@ def _get_api_type(cls, stream_name, properties):
properties_not_supported_by_bulk = {
key: value for key, value in properties.items() if value.get("format") == "base64" or "object" in value["type"]
}
properties_length = len(",".join(p for p in properties))

rest_required = stream_name in UNSUPPORTED_BULK_API_SALESFORCE_OBJECTS or properties_not_supported_by_bulk
# If we have a lot of properties we can overcome REST API URL length and get an error: "reason: URI Too Long".
# For such cases connector tries to use BULK API because it uses POST request and passes properties in the request body.
bulk_required = properties_length + 2000 > Salesforce.REQUEST_SIZE_LIMITS

if rest_required and not bulk_required:
if rest_required:
return "rest"
if not rest_required:
return "bulk"
return "bulk"

@classmethod
def generate_streams(
Expand All @@ -79,6 +72,7 @@ def generate_streams(
sf_object: Salesforce,
) -> List[Stream]:
""" "Generates a list of stream by their names. It can be used for different tests too"""
logger = logging.getLogger()
authenticator = TokenAuthenticator(sf_object.access_token)
stream_properties = sf_object.generate_schemas(stream_objects)
streams = []
Expand All @@ -88,7 +82,7 @@ def generate_streams(

api_type = cls._get_api_type(stream_name, selected_properties)
if api_type == "rest":
full_refresh, incremental = SalesforceStream, IncrementalSalesforceStream
full_refresh, incremental = RestSalesforceStream, IncrementalRestSalesforceStream
elif api_type == "bulk":
full_refresh, incremental = BulkSalesforceStream, BulkIncrementalSalesforceStream
else:
Expand All @@ -98,10 +92,17 @@ def generate_streams(
pk, replication_key = sf_object.get_pk_and_replication_key(json_schema)
streams_kwargs.update(dict(sf_api=sf_object, pk=pk, stream_name=stream_name, schema=json_schema, authenticator=authenticator))
if replication_key and stream_name not in UNSUPPORTED_FILTERING_STREAMS:
streams.append(incremental(**streams_kwargs, replication_key=replication_key, start_date=config.get("start_date")))
stream = incremental(**streams_kwargs, replication_key=replication_key, start_date=config.get("start_date"))
else:
streams.append(full_refresh(**streams_kwargs))

stream = full_refresh(**streams_kwargs)
if api_type == "rest" and not stream.primary_key and stream.too_many_properties:
logger.warning(
f"Can not instantiate stream {stream_name}. "
f"It is not supported by the BULK API and can not be implemented via REST because the number of its properties "
f"exceeds the limit and it lacks a primary key."
)
continue
streams.append(stream)
return streams

def streams(self, config: Mapping[str, Any]) -> List[Stream]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,14 @@
import time
from abc import ABC
from contextlib import closing
from typing import Any, Iterable, List, Mapping, MutableMapping, Optional, Tuple, Type, Union
from typing import Any, Callable, Iterable, List, Mapping, MutableMapping, Optional, Tuple, Type, Union

import pandas as pd
import pendulum
import requests # type: ignore[import]
from airbyte_cdk.models import ConfiguredAirbyteCatalog, SyncMode
from airbyte_cdk.sources.streams import Stream
from airbyte_cdk.sources.streams.availability_strategy import AvailabilityStrategy
from airbyte_cdk.sources.streams.core import Stream, StreamData
from airbyte_cdk.sources.streams.http import HttpStream
from airbyte_cdk.sources.utils.transform import TransformConfig, TypeTransformer
from numpy import nan
Expand All @@ -38,6 +38,7 @@ class SalesforceStream(HttpStream, ABC):
page_size = 2000
transformer = TypeTransformer(TransformConfig.DefaultSchemaNormalization)
encoding = DEFAULT_ENCODING
MAX_PROPERTIES_LENGTH = Salesforce.REQUEST_SIZE_LIMITS - 2000

def __init__(
self, sf_api: Salesforce, pk: str, stream_name: str, sobject_options: Mapping[str, Any] = None, schema: dict = None, **kwargs
Expand Down Expand Up @@ -65,6 +66,31 @@ def url_base(self) -> str:
def availability_strategy(self) -> Optional["AvailabilityStrategy"]:
return None

@property
def too_many_properties(self):
selected_properties = self.get_json_schema().get("properties", {})
properties_length = len(",".join(p for p in selected_properties))
return properties_length > self.MAX_PROPERTIES_LENGTH

def parse_response(self, response: requests.Response, **kwargs) -> Iterable[Mapping]:
yield from response.json()["records"]

def get_json_schema(self) -> Mapping[str, Any]:
if not self.schema:
self.schema = self.sf_api.generate_schema(self.name)
return self.schema

def get_error_display_message(self, exception: BaseException) -> Optional[str]:
if isinstance(exception, exceptions.ConnectionError):
return f"After {self.max_retries} retries the connector has failed with a network error. It looks like Salesforce API experienced temporary instability, please try again later."
return super().get_error_display_message(exception)


class RestSalesforceStream(SalesforceStream):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
assert self.primary_key or not self.too_many_properties

def path(self, next_page_token: Mapping[str, Any] = None, **kwargs: Any) -> str:
if next_page_token:
"""
Expand All @@ -80,7 +106,11 @@ def next_page_token(self, response: requests.Response) -> Optional[Mapping[str,
return {"next_token": next_token} if next_token else None

def request_params(
self, stream_state: Mapping[str, Any], stream_slice: Mapping[str, Any] = None, next_page_token: Mapping[str, Any] = None
self,
stream_state: Mapping[str, Any],
stream_slice: Mapping[str, Any] = None,
next_page_token: Mapping[str, Any] = None,
property_chunk: Mapping[str, Any] = None,
) -> MutableMapping[str, Any]:
"""
Salesforce SOQL Query: https://developer.salesforce.com/docs/atlas.en-us.232.0.api_rest.meta/api_rest/dome_queryall.htm
Expand All @@ -91,32 +121,44 @@ def request_params(
"""
return {}

selected_properties = self.get_json_schema().get("properties", {})
query = f"SELECT {','.join(selected_properties.keys())} FROM {self.name} "
property_chunk = property_chunk or {}
query = f"SELECT {','.join(property_chunk.keys())} FROM {self.name} "

if self.primary_key and self.name not in UNSUPPORTED_FILTERING_STREAMS:
query += f"ORDER BY {self.primary_key} ASC"

return {"q": query}

def parse_response(self, response: requests.Response, **kwargs) -> Iterable[Mapping]:
yield from response.json()["records"]
def chunk_properties(self) -> Iterable[Mapping[str, Any]]:
selected_properties = self.get_json_schema().get("properties", {})

def get_json_schema(self) -> Mapping[str, Any]:
if not self.schema:
self.schema = self.sf_api.generate_schema(self.name)
return self.schema
summary_length = 0
local_properties = {}
for property_name, value in selected_properties.items():
current_property_length = len(property_name) + 1 # properties are split with commas
if current_property_length + summary_length >= self.MAX_PROPERTIES_LENGTH:
yield local_properties
local_properties = {}
summary_length = 0

local_properties[property_name] = value
summary_length += current_property_length

if local_properties:
yield local_properties

def read_records(
self,
sync_mode: SyncMode,
cursor_field: List[str] = None,
stream_slice: Mapping[str, Any] = None,
stream_state: Mapping[str, Any] = None,
) -> Iterable[Mapping[str, Any]]:
) -> Iterable[StreamData]:
try:
yield from super().read_records(
sync_mode=sync_mode, cursor_field=cursor_field, stream_slice=stream_slice, stream_state=stream_state
yield from self._read_pages(
lambda req, res, state, _slice: self.parse_response(res, stream_slice=_slice, stream_state=state),
stream_slice,
stream_state,
)
except exceptions.HTTPError as error:
"""
Expand All @@ -135,10 +177,83 @@ def read_records(
return
raise error

def get_error_display_message(self, exception: BaseException) -> Optional[str]:
if isinstance(exception, exceptions.ConnectionError):
return f"After {self.max_retries} retries the connector has failed with a network error. It looks like Salesforce API experienced temporary instability, please try again later."
return super().get_error_display_message(exception)
def _read_pages(
self,
records_generator_fn: Callable[
[requests.PreparedRequest, requests.Response, Mapping[str, Any], Mapping[str, Any]], Iterable[StreamData]
],
stream_slice: Mapping[str, Any] = None,
stream_state: Mapping[str, Any] = None,
) -> Iterable[StreamData]:
stream_state = stream_state or {}
pagination_complete = False
records = {}
next_pages = {}

while not pagination_complete:
index = 0
for index, property_chunk in enumerate(self.chunk_properties()):
request, response = self._fetch_next_page(stream_slice, stream_state, next_pages.get(index), property_chunk)
next_pages[index] = self.next_page_token(response)
chunk_page_records = records_generator_fn(request, response, stream_state, stream_slice)
if not self.too_many_properties:
# this is the case when a stream has no primary key
# (is allowed when properties length does not exceed the maximum value)
# so there would be a single iteration, therefore we may and should yield records immediately
yield from chunk_page_records
break
chunk_page_records = {record[self.primary_key]: record for record in chunk_page_records}

for record_id, record in chunk_page_records.items():
if record_id not in records:
records[record_id] = (record, 1)
continue
incomplete_record, counter = records[record_id]
incomplete_record.update(record)
counter += 1
records[record_id] = (incomplete_record, counter)

for record_id, (record, counter) in records.items():
if counter != index + 1:
# Because we make multiple calls to query N records (each call to fetch X properties of all the N records),
# there's a chance that the number of records corresponding to the query may change between the calls. This
# may result in data inconsistency. We skip such records for now and log a warning message.
self.logger.warning(
f"Inconsistent record with primary key {record_id} found. It consists of {counter} chunks instead of {index + 1}. "
f"Skipping it."
)
continue
yield record

records = {}

if not any(next_pages.values()):
pagination_complete = True

# Always return an empty generator just in case no records were ever yielded
yield from []

def _fetch_next_page(
self,
stream_slice: Mapping[str, Any] = None,
stream_state: Mapping[str, Any] = None,
next_page_token: Mapping[str, Any] = None,
property_chunk: Mapping[str, Any] = None,
) -> Tuple[requests.PreparedRequest, requests.Response]:
request_headers = self.request_headers(stream_state=stream_state, stream_slice=stream_slice, next_page_token=next_page_token)
request = self._create_prepared_request(
path=self.path(stream_state=stream_state, stream_slice=stream_slice, next_page_token=next_page_token),
headers=dict(request_headers, **self.authenticator.get_auth_header()),
params=self.request_params(
stream_state=stream_state, stream_slice=stream_slice, next_page_token=next_page_token, property_chunk=property_chunk
),
json=self.request_body_json(stream_state=stream_state, stream_slice=stream_slice, next_page_token=next_page_token),
data=self.request_body_data(stream_state=stream_state, stream_slice=stream_slice, next_page_token=next_page_token),
)
request_kwargs = self.request_kwargs(stream_state=stream_state, stream_slice=stream_slice, next_page_token=next_page_token)

response = self._send_request(request, request_kwargs)
return request, response


class BulkSalesforceStream(SalesforceStream):
Expand Down Expand Up @@ -406,10 +521,10 @@ def get_standard_instance(self) -> SalesforceStream:
sobject_options=self.sobject_options,
authenticator=self.authenticator,
)
new_cls: Type[SalesforceStream] = SalesforceStream
new_cls: Type[SalesforceStream] = RestSalesforceStream
if isinstance(self, BulkIncrementalSalesforceStream):
stream_kwargs.update({"replication_key": self.replication_key, "start_date": self.start_date})
new_cls = IncrementalSalesforceStream
new_cls = IncrementalRestSalesforceStream

return new_cls(**stream_kwargs)

Expand All @@ -426,7 +541,7 @@ def transform_empty_string_to_none(instance: Any, schema: Any):
return instance


class IncrementalSalesforceStream(SalesforceStream, ABC):
class IncrementalRestSalesforceStream(RestSalesforceStream, ABC):
state_checkpoint_interval = 500

def __init__(self, replication_key: str, start_date: Optional[str], **kwargs):
Expand All @@ -442,20 +557,24 @@ def format_start_date(start_date: Optional[str]) -> Optional[str]:
return None

def request_params(
self, stream_state: Mapping[str, Any], stream_slice: Mapping[str, Any] = None, next_page_token: Mapping[str, Any] = None
self,
stream_state: Mapping[str, Any],
stream_slice: Mapping[str, Any] = None,
next_page_token: Mapping[str, Any] = None,
property_chunk: Mapping[str, Any] = None,
) -> MutableMapping[str, Any]:
if next_page_token:
"""
If `next_page_token` is set, subsequent requests use `nextRecordsUrl`, and do not include any parameters.
"""
return {}

selected_properties = self.get_json_schema().get("properties", {})
property_chunk = property_chunk or {}

stream_date = stream_state.get(self.cursor_field)
start_date = stream_date or self.start_date

query = f"SELECT {','.join(selected_properties.keys())} FROM {self.name} "
query = f"SELECT {','.join(property_chunk.keys())} FROM {self.name} "
if start_date:
query += f"WHERE {self.cursor_field} >= {start_date} "
if self.name not in UNSUPPORTED_FILTERING_STREAMS:
Expand All @@ -477,7 +596,7 @@ def get_updated_state(self, current_stream_state: MutableMapping[str, Any], late
return {self.cursor_field: latest_benchmark}


class BulkIncrementalSalesforceStream(BulkSalesforceStream, IncrementalSalesforceStream):
class BulkIncrementalSalesforceStream(BulkSalesforceStream, IncrementalRestSalesforceStream):
def next_page_token(self, last_record: Mapping[str, Any]) -> Optional[Mapping[str, Any]]:
if self.name not in UNSUPPORTED_FILTERING_STREAMS:
page_token: str = last_record[self.cursor_field]
Expand Down