🎉 Source Salesforce: speed up discovery >20x by leveraging parallel API calls (#10516)
antixar committed Feb 28, 2022
1 parent f4d54a9 commit 2bba529
Showing 15 changed files with 377 additions and 250 deletions.
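
The core change is the new generate_schemas method in source_salesforce/api.py: instead of issuing one Salesforce describe request per stream sequentially, discovery now fans the requests out over a thread pool in chunks of parallel_tasks_size (100). Below is a minimal standalone sketch of that pattern, assuming a valid session, instance URL, and access token; the describe_object/describe_all helpers are illustrative and not part of the connector.

import concurrent.futures
from typing import Any, Dict, List, Mapping

import requests

PARALLEL_TASKS_SIZE = 100  # mirrors Salesforce.parallel_tasks_size introduced in this diff


def describe_object(session: requests.Session, instance_url: str, token: str, name: str) -> Mapping[str, Any]:
    # Illustrative helper: one "describe" call per Salesforce object, which is
    # what discovery must do once per stream to build its JSON schema.
    url = f"{instance_url}/services/data/v52.0/sobjects/{name}/describe"
    resp = session.get(url, headers={"Authorization": f"Bearer {token}"})
    resp.raise_for_status()
    return resp.json()


def describe_all(session: requests.Session, instance_url: str, token: str, names: List[str]) -> Dict[str, Mapping[str, Any]]:
    schemas: Dict[str, Mapping[str, Any]] = {}
    # Process the object list in chunks so the number of worker threads never
    # exceeds the HTTP connection pool size.
    for i in range(0, len(names), PARALLEL_TASKS_SIZE):
        chunk = names[i : i + PARALLEL_TASKS_SIZE]
        with concurrent.futures.ThreadPoolExecutor(max_workers=len(chunk)) as executor:
            results = executor.map(lambda n: describe_object(session, instance_url, token, n), chunk)
            for name, schema in zip(chunk, results):
                schemas[name] = schema
    return schemas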
2 changes: 1 addition & 1 deletion .github/workflows/sonar-scan.yml
@@ -72,7 +72,7 @@ jobs:
sonar-token: ${{ secrets.SONAR_TOKEN }}
sonar-gcp-access-key: ${{ secrets.GCP_SONAR_SA_KEY }}
pull-request-id: "${{ github.repository }}/${{ github.event.pull_request.number }}"
remove-sonar-project: ${{ github.event.action == 'closed' }}
remove-sonar-project: true



@@ -669,7 +669,7 @@
- name: Salesforce
sourceDefinitionId: b117307c-14b6-41aa-9422-947e34922962
dockerRepository: airbyte/source-salesforce
dockerImageTag: 0.1.23
dockerImageTag: 1.0.0
documentationUrl: https://docs.airbyte.io/integrations/sources/salesforce
icon: salesforce.svg
sourceType: api
@@ -7064,7 +7064,7 @@
supportsNormalization: false
supportsDBT: false
supported_destination_sync_modes: []
- dockerImage: "airbyte/source-salesforce:0.1.23"
- dockerImage: "airbyte/source-salesforce:1.0.0"
spec:
documentationUrl: "https://docs.airbyte.com/integrations/sources/salesforce"
connectionSpecification:
@@ -25,5 +25,5 @@ COPY source_salesforce ./source_salesforce
ENV AIRBYTE_ENTRYPOINT "python /airbyte/integration_code/main.py"
ENTRYPOINT ["python", "/airbyte/integration_code/main.py"]

LABEL io.airbyte.version=0.1.23
LABEL io.airbyte.version=1.0.0
LABEL io.airbyte.name=airbyte/source-salesforce
@@ -4,6 +4,7 @@

import base64
import json
import time
from datetime import datetime
from pathlib import Path

@@ -71,15 +72,16 @@ def get_stream_state():

def test_update_for_deleted_record(stream):
headers = stream.authenticator.get_auth_header()
stream_state = get_stream_state()
time.sleep(1)
response = create_note(stream, headers)
assert response.status_code == 201, "Note was note created"
assert response.status_code == 201, "Note was not created"

created_note_id = response.json()["id"]

notes = set(record["Id"] for record in stream.read_records(sync_mode=None))
assert created_note_id in notes, "No created note during the sync"
assert created_note_id in notes, "The stream didn't return the note we created"

stream_state = get_stream_state()
response = delete_note(stream, created_note_id, headers)
assert response.status_code == 204, "Note was not deleted"

@@ -93,12 +95,9 @@ def test_update_for_deleted_record(stream):
assert is_note_updated, "No deleted note during the sync"
assert is_deleted, "Wrong field value for deleted note during the sync"

stream_state = get_stream_state()
time.sleep(1)
response = update_note(stream, created_note_id, headers)
assert response.status_code == 404, "Note was updated, but should not"

notes = set(record["Id"] for record in stream.read_records(sync_mode=SyncMode.incremental, stream_state=stream_state))
assert created_note_id not in notes, "Note was updated, but should not"
assert response.status_code == 404, "Expected an update to a deleted note to return 404"


def test_deleted_record(stream):
@@ -126,3 +125,24 @@ def test_deleted_record(stream):
assert record, "No updated note during the sync"
assert record["IsDeleted"], "Wrong field value for deleted note during the sync"
assert record["TextPreview"] == UPDATED_NOTE_CONTENT and record["TextPreview"] != NOTE_CONTENT, "Note Content was not updated"


def test_parallel_discover(input_sandbox_config):
sf = Salesforce(**input_sandbox_config)
sf.login()
stream_objects = sf.get_validated_streams(config=input_sandbox_config)

# load all schemas with the old sequential (one-by-one) logic
consecutive_schemas = {}
start_time = datetime.now()
for stream_name, sobject_options in stream_objects.items():
consecutive_schemas[stream_name] = sf.generate_schema(stream_name, sobject_options)
consecutive_loading_time = (datetime.now() - start_time).total_seconds()
start_time = datetime.now()
parallel_schemas = sf.generate_schemas(stream_objects)
parallel_loading_time = (datetime.now() - start_time).total_seconds()

assert parallel_loading_time < consecutive_loading_time / 5.0, "parallel discovery should be at least 5x faster than sequential"
assert set(consecutive_schemas.keys()) == set(parallel_schemas.keys())
for stream_name, schema in consecutive_schemas.items():
assert schema == parallel_schemas[stream_name]
@@ -2,12 +2,14 @@
# Copyright (c) 2021 Airbyte, Inc., all rights reserved.
#

import concurrent.futures
from typing import Any, List, Mapping, Optional, Tuple

import requests
import requests # type: ignore[import]
from airbyte_cdk import AirbyteLogger
from airbyte_cdk.models import ConfiguredAirbyteCatalog
from requests.exceptions import HTTPError
from requests import adapters as request_adapters
from requests.exceptions import HTTPError, RequestException # type: ignore[import]

from .exceptions import TypeSalesforceException
from .rate_limiting import default_backoff_handler
@@ -173,6 +175,7 @@
class Salesforce:
logger = AirbyteLogger()
version = "v52.0"
parallel_tasks_size = 100

def __init__(
self,
@@ -182,21 +185,25 @@ def __init__(
client_secret: str = None,
is_sandbox: bool = None,
start_date: str = None,
**kwargs,
):
**kwargs: Any,
) -> None:
self.refresh_token = refresh_token
self.token = token
self.client_id = client_id
self.client_secret = client_secret
self.access_token = None
self.instance_url = None
self.instance_url = ""
self.session = requests.Session()
# Change the connection pool size. Default value is not enough for parallel tasks
adapter = request_adapters.HTTPAdapter(pool_connections=self.parallel_tasks_size, pool_maxsize=self.parallel_tasks_size)
self.session.mount("https://", adapter)

self.is_sandbox = is_sandbox in [True, "true"]
if self.is_sandbox:
self.logger.info("using SANDBOX of Salesforce")
self.start_date = start_date
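
The connection-pool change above matters because requests mounts HTTPAdapter instances with a default of 10 pooled connections per host; with up to 100 discovery worker threads, most connections would be opened, used once, and discarded (urllib3 logs "Connection pool is full, discarding connection" in that situation). A hedged standalone sketch of the same setup, with the size as an assumed constant for the example:

import requests
from requests import adapters as request_adapters

POOL_SIZE = 100  # assumption for this sketch; matches parallel_tasks_size above

session = requests.Session()
# Raise both the number of cached host pools and the per-host pool size so every
# discovery worker thread can hold a keep-alive connection to Salesforce.
adapter = request_adapters.HTTPAdapter(pool_connections=POOL_SIZE, pool_maxsize=POOL_SIZE)
session.mount("https://", adapter)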

def _get_standard_headers(self):
def _get_standard_headers(self) -> Mapping[str, str]:
return {"Authorization": "Bearer {}".format(self.access_token)}

def get_streams_black_list(self) -> List[str]:
@@ -240,7 +247,7 @@ def get_validated_streams(self, config: Mapping[str, Any], catalog: ConfiguredAi
validated_streams = [stream_name for stream_name in stream_names if self.filter_streams(stream_name)]
return {stream_name: sobject_options for stream_name, sobject_options in stream_objects.items() if stream_name in validated_streams}

@default_backoff_handler(max_tries=5, factor=15)
@default_backoff_handler(max_tries=5, factor=5)
def _make_request(
self, http_method: str, url: str, headers: dict = None, body: dict = None, stream: bool = False, params: dict = None
) -> requests.models.Response:
@@ -280,15 +287,39 @@ def describe(self, sobject: str = None, sobject_options: Mapping[str, Any] = Non
resp = self._make_request("GET", url, headers=headers)
if resp.status_code == 404 and sobject:
self.logger.error(f"not found a description for the sobject '{sobject}'. Sobject options: {sobject_options}")
return resp.json()
resp_json: Mapping[str, Any] = resp.json()
return resp_json

def generate_schema(self, stream_name: str = None, stream_options: Mapping[str, Any] = None) -> Mapping[str, Any]:
response = self.describe(stream_name, stream_options)
schema = {"$schema": "http://json-schema.org/draft-07/schema#", "type": "object", "additionalProperties": True, "properties": {}}
for field in response["fields"]:
schema["properties"][field["name"]] = self.field_to_property_schema(field)
schema["properties"][field["name"]] = self.field_to_property_schema(field) # type: ignore[index]
return schema

def generate_schemas(self, stream_objects: Mapping[str, Any]) -> Mapping[str, Any]:
def load_schema(name: str, stream_options: Mapping[str, Any]) -> Tuple[str, Optional[Mapping[str, Any]], Optional[str]]:
try:
result = self.generate_schema(stream_name=name, stream_options=stream_options)
except RequestException as e:
return name, None, str(e)
return name, result, None

stream_names = list(stream_objects.keys())
# split all requests into chunks of parallel_tasks_size
stream_schemas = {}
for i in range(0, len(stream_names), self.parallel_tasks_size):
chunk_stream_names = stream_names[i : i + self.parallel_tasks_size]
with concurrent.futures.ThreadPoolExecutor(max_workers=len(chunk_stream_names)) as executor:
for stream_name, schema, err in executor.map(
lambda args: load_schema(*args), [(stream_name, stream_objects[stream_name]) for stream_name in chunk_stream_names]
):
if err:
self.logger.error(f"Loading error of the {stream_name} schema: {err}")
continue
stream_schemas[stream_name] = schema
return stream_schemas
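
Two design choices in generate_schemas are worth noting: the chunk size equals parallel_tasks_size, which is also the connection-pool size set in __init__, so each worker thread can reuse a pooled keep-alive connection; and a failed describe call only logs an error and drops that stream from the result instead of failing the whole discovery. For context, this is roughly how the connector-level code consumes the result (a condensed excerpt of the generate_streams change further down in this diff):

# One parallel discovery pass per connector instead of one describe call per stream:
stream_properties = sf_object.generate_schemas(stream_objects)
for stream_name, sobject_options in stream_objects.items():
    json_schema = stream_properties.get(stream_name, {})
    selected_properties = json_schema.get("properties", {})
    # ... build the stream using json_schema / selected_properties ...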

@staticmethod
def get_pk_and_replication_key(json_schema: Mapping[str, Any]) -> Tuple[Optional[str], Optional[str]]:
fields_list = json_schema.get("properties", {}).keys()
@@ -314,13 +345,16 @@ def field_to_property_schema(field_params: Mapping[str, Any]) -> Mapping[str, An
if sf_type in STRING_TYPES:
property_schema["type"] = ["string", "null"]
elif sf_type in DATE_TYPES:
property_schema = {"type": ["string", "null"], "format": "date-time" if sf_type == "datetime" else "date"}
property_schema = {
"type": ["string", "null"],
"format": "date-time" if sf_type == "datetime" else "date", # type: ignore[dict-item]
}
elif sf_type in NUMBER_TYPES:
property_schema["type"] = ["number", "null"]
elif sf_type == "address":
property_schema = {
"type": ["object", "null"],
"properties": {
"properties": { # type: ignore[dict-item]
"street": {"type": ["null", "string"]},
"state": {"type": ["null", "string"]},
"postalCode": {"type": ["null", "string"]},
@@ -332,7 +366,7 @@ def field_to_property_schema(field_params: Mapping[str, Any]) -> Mapping[str, An
},
}
elif sf_type == "base64":
property_schema = {"type": ["string", "null"], "format": "base64"}
property_schema = {"type": ["string", "null"], "format": "base64"} # type: ignore[dict-item]
elif sf_type == "int":
property_schema["type"] = ["integer", "null"]
elif sf_type == "boolean":
@@ -346,7 +380,10 @@ def field_to_property_schema(field_params: Mapping[str, Any]) -> Mapping[str, An
elif sf_type == "location":
property_schema = {
"type": ["object", "null"],
"properties": {"longitude": {"type": ["null", "number"]}, "latitude": {"type": ["null", "number"]}},
"properties": { # type: ignore[dict-item]
"longitude": {"type": ["null", "number"]},
"latitude": {"type": ["null", "number"]},
},
}
else:
raise TypeSalesforceException("Found unsupported type: {}".format(sf_type))
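
To make the type mapping above concrete, here is roughly what field_to_property_schema produces for two of the branches shown in this diff; the input dictionaries are trimmed, illustrative examples rather than captured connector output.

from source_salesforce.api import Salesforce

Salesforce.field_to_property_schema({"name": "LastModifiedDate", "type": "datetime"})
# -> {"type": ["string", "null"], "format": "date-time"}

Salesforce.field_to_property_schema({"name": "Coordinates__c", "type": "location"})
# -> {"type": ["object", "null"],
#     "properties": {"longitude": {"type": ["null", "number"]},
#                    "latitude": {"type": ["null", "number"]}}}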
@@ -3,7 +3,13 @@
#


class TypeSalesforceException(Exception):
class SalesforceException(Exception):
"""
Default Salesforce exception.
"""


class TypeSalesforceException(SalesforceException):
"""
We use this exception for unknown input data types for Salesforce.
"""
@@ -8,7 +8,7 @@
import backoff
from airbyte_cdk.logger import AirbyteLogger
from airbyte_cdk.sources.streams.http.exceptions import DefaultBackoffException
from requests import codes, exceptions
from requests import codes, exceptions # type: ignore[import]

TRANSIENT_EXCEPTIONS = (
DefaultBackoffException,
@@ -3,15 +3,15 @@
#

import copy
from typing import Any, Iterator, List, Mapping, MutableMapping, Tuple
from typing import Any, Iterator, List, Mapping, MutableMapping, Optional, Tuple

from airbyte_cdk import AirbyteLogger
from airbyte_cdk.models import AirbyteMessage, ConfiguredAirbyteCatalog
from airbyte_cdk.sources import AbstractSource
from airbyte_cdk.sources.streams import Stream
from airbyte_cdk.sources.streams.http.auth import TokenAuthenticator
from airbyte_cdk.sources.utils.schema_helpers import split_config
from requests import codes, exceptions
from requests import codes, exceptions # type: ignore[import]

from .api import UNSUPPORTED_BULK_API_SALESFORCE_OBJECTS, UNSUPPORTED_FILTERING_STREAMS, Salesforce
from .streams import BulkIncrementalSalesforceStream, BulkSalesforceStream, IncrementalSalesforceStream, SalesforceStream
@@ -24,17 +24,18 @@ def _get_sf_object(config: Mapping[str, Any]) -> Salesforce:
sf.login()
return sf

def check_connection(self, logger: AirbyteLogger, config: Mapping[str, Any]) -> Tuple[bool, any]:
def check_connection(self, logger: AirbyteLogger, config: Mapping[str, Any]) -> Tuple[bool, Optional[str]]:
try:
_ = self._get_sf_object(config)
return True, None
except exceptions.HTTPError as error:
error_data = error.response.json()[0]
error_code = error_data.get("errorCode")
if error.response.status_code == codes.FORBIDDEN and error_code == "REQUEST_LIMIT_EXCEEDED":
logger.warn(f"API Call limit is exceeded. Error message: '{error_data.get('message')}'")
return False, "API Call limit is exceeded"

return True, None

@classmethod
def generate_streams(
cls,
@@ -45,12 +46,13 @@ def generate_streams(
) -> List[Stream]:
""" "Generates a list of stream by their names. It can be used for different tests too"""
authenticator = TokenAuthenticator(sf_object.access_token)
stream_properties = sf_object.generate_schemas(stream_objects)
streams = []
for stream_name, sobject_options in stream_objects.items():
streams_kwargs = {"sobject_options": sobject_options}
stream_state = state.get(stream_name, {}) if state else {}

selected_properties = sf_object.generate_schema(stream_name, sobject_options).get("properties", {})
selected_properties = stream_properties.get(stream_name, {}).get("properties", {})
# Salesforce BULK API currently does not support loading fields with data type base64 and compound data
properties_not_supported_by_bulk = {
key: value for key, value in selected_properties.items() if value.get("format") == "base64" or "object" in value["type"]
@@ -63,7 +65,7 @@ def generate_streams(
# Use BULK API
full_refresh, incremental = BulkSalesforceStream, BulkIncrementalSalesforceStream

json_schema = sf_object.generate_schema(stream_name, stream_objects)
json_schema = stream_properties.get(stream_name, {})
pk, replication_key = sf_object.get_pk_and_replication_key(json_schema)
streams_kwargs.update(dict(sf_api=sf_object, pk=pk, stream_name=stream_name, schema=json_schema, authenticator=authenticator))
if replication_key and stream_name not in UNSUPPORTED_FILTERING_STREAMS:
@@ -79,7 +81,11 @@ def streams(self, config: Mapping[str, Any], catalog: ConfiguredAirbyteCatalog =
return self.generate_streams(config, stream_objects, sf, state=state)

def read(
self, logger: AirbyteLogger, config: Mapping[str, Any], catalog: ConfiguredAirbyteCatalog, state: MutableMapping[str, Any] = None
self,
logger: AirbyteLogger,
config: Mapping[str, Any],
catalog: ConfiguredAirbyteCatalog,
state: Optional[MutableMapping[str, Any]] = None,
) -> Iterator[AirbyteMessage]:
"""
Overwritten to dynamically receive only those streams that are necessary for reading for significant speed gains