Source salesforce: handle Japanese characters (#17001)
* #454 oncall source salesforce: handle Japanese characters

* source salesforce: update changelog

* source salesforce: fix flake8 warnings

* #454 source salesforce: align the public interface with the CDK; stop taking state into account when choosing the API type

* auto-bump connector version [ci skip]

Co-authored-by: Octavia Squidington III <octavia-squidington-iii@users.noreply.github.com>
davydov-d and octavia-squidington-iii committed Sep 22, 2022
1 parent 083b5d1 commit 596a436
Showing 12 changed files with 120 additions and 122 deletions.
@@ -908,7 +908,7 @@
- name: Salesforce
sourceDefinitionId: b117307c-14b6-41aa-9422-947e34922962
dockerRepository: airbyte/source-salesforce
dockerImageTag: 1.0.15
dockerImageTag: 1.0.16
documentationUrl: https://docs.airbyte.io/integrations/sources/salesforce
icon: salesforce.svg
sourceType: api
@@ -9379,7 +9379,7 @@
supportsNormalization: false
supportsDBT: false
supported_destination_sync_modes: []
- dockerImage: "airbyte/source-salesforce:1.0.15"
- dockerImage: "airbyte/source-salesforce:1.0.16"
spec:
documentationUrl: "https://docs.airbyte.com/integrations/sources/salesforce"
connectionSpecification:
@@ -13,5 +13,5 @@ RUN pip install .

ENTRYPOINT ["python", "/airbyte/integration_code/main.py"]

LABEL io.airbyte.version=1.0.15
LABEL io.airbyte.version=1.0.16
LABEL io.airbyte.name=airbyte/source-salesforce
@@ -2,12 +2,12 @@
# Copyright (c) 2022 Airbyte, Inc., all rights reserved.
#

import copy
from typing import Any, Iterator, List, Mapping, MutableMapping, Optional, Tuple
from typing import Any, Iterator, List, Mapping, MutableMapping, Optional, Tuple, Union

from airbyte_cdk import AirbyteLogger
from airbyte_cdk.models import AirbyteMessage, ConfiguredAirbyteCatalog
from airbyte_cdk.models import AirbyteMessage, AirbyteStateMessage, ConfiguredAirbyteCatalog
from airbyte_cdk.sources import AbstractSource
from airbyte_cdk.sources.connector_state_manager import ConnectorStateManager
from airbyte_cdk.sources.streams import Stream
from airbyte_cdk.sources.streams.http.auth import TokenAuthenticator
from airbyte_cdk.sources.utils.schema_helpers import split_config
@@ -37,7 +37,7 @@ def check_connection(self, logger: AirbyteLogger, config: Mapping[str, Any]) ->
return True, None

@classmethod
def _get_api_type(cls, stream_name, properties, stream_state):
def _get_api_type(cls, stream_name, properties):
# Salesforce BULK API currently does not support loading fields with data type base64 and compound data
properties_not_supported_by_bulk = {
key: value for key, value in properties.items() if value.get("format") == "base64" or "object" in value["type"]
@@ -49,31 +49,27 @@ def _get_api_type(cls, stream_name, properties, stream_state):
# For such cases connector tries to use BULK API because it uses POST request and passes properties in the request body.
bulk_required = properties_length + 2000 > Salesforce.REQUEST_SIZE_LIMITS

if bulk_required and not rest_required:
return "bulk"
elif rest_required and not bulk_required:
if rest_required and not bulk_required:
return "rest"
elif not bulk_required and not rest_required:
return "rest" if stream_state else "bulk"
if not rest_required:
return "bulk"

@classmethod
def generate_streams(
cls,
config: Mapping[str, Any],
stream_objects: Mapping[str, Any],
sf_object: Salesforce,
state: Mapping[str, Any] = None,
) -> List[Stream]:
""" "Generates a list of stream by their names. It can be used for different tests too"""
authenticator = TokenAuthenticator(sf_object.access_token)
stream_properties = sf_object.generate_schemas(stream_objects)
streams = []
for stream_name, sobject_options in stream_objects.items():
streams_kwargs = {"sobject_options": sobject_options}
stream_state = state.get(stream_name, {}) if state else {}
selected_properties = stream_properties.get(stream_name, {}).get("properties", {})

api_type = cls._get_api_type(stream_name, selected_properties, stream_state)
api_type = cls._get_api_type(stream_name, selected_properties)
if api_type == "rest":
full_refresh, incremental = SalesforceStream, IncrementalSalesforceStream
elif api_type == "bulk":
@@ -91,10 +87,10 @@ def generate_streams(

return streams

def streams(self, config: Mapping[str, Any], catalog: ConfiguredAirbyteCatalog = None, state: Mapping[str, Any] = None) -> List[Stream]:
def streams(self, config: Mapping[str, Any], catalog: ConfiguredAirbyteCatalog = None) -> List[Stream]:
sf = self._get_sf_object(config)
stream_objects = sf.get_validated_streams(config=config, catalog=catalog)
streams = self.generate_streams(config, stream_objects, sf, state=state)
streams = self.generate_streams(config, stream_objects, sf)
streams.append(Describe(sf_api=sf, catalog=catalog))
return streams

@@ -103,17 +99,17 @@ def read(
logger: AirbyteLogger,
config: Mapping[str, Any],
catalog: ConfiguredAirbyteCatalog,
state: Optional[MutableMapping[str, Any]] = None,
state: Union[List[AirbyteStateMessage], MutableMapping[str, Any]] = None,
) -> Iterator[AirbyteMessage]:
"""
Overwritten to dynamically receive only those streams that are necessary for reading for significant speed gains
(Salesforce has a strict API limit on requests).
"""
connector_state = copy.deepcopy(state or {})
config, internal_config = split_config(config)
# get the streams once in case the connector needs to make any queries to generate them
logger.info("Starting generating streams")
stream_instances = {s.name: s for s in self.streams(config, catalog=catalog, state=state)}
stream_instances = {s.name: s for s in self.streams(config, catalog=catalog)}
state_manager = ConnectorStateManager(stream_instance_map=stream_instances, state=state)
logger.info(f"Starting syncing {self.name}")
self._stream_to_instance_map = stream_instances
for configured_stream in catalog.streams:
Expand All @@ -128,7 +124,7 @@ def read(
logger=logger,
stream_instance=stream_instance,
configured_stream=configured_stream,
connector_state=connector_state,
state_manager=state_manager,
internal_config=internal_config,
)
except exceptions.HTTPError as error:
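
For reference, a minimal sketch of the selection rule these hunks leave in place: stream state no longer plays a role, only the stream's property schema does. This is an illustration, not the connector's actual method; `REQUEST_SIZE_LIMITS` and the way `properties_length` is computed are assumptions standing in for `Salesforce.REQUEST_SIZE_LIMITS` and the connector's own size estimate.

```python
from typing import Any, Mapping, Optional

# Assumption: stand-in for Salesforce.REQUEST_SIZE_LIMITS; the real constant lives in the connector.
REQUEST_SIZE_LIMITS = 16_384


def choose_api_type(properties: Mapping[str, Any]) -> Optional[str]:
    # BULK API cannot load base64 or compound (object) fields, so they force the REST API.
    unsupported_by_bulk = {
        name: schema
        for name, schema in properties.items()
        if schema.get("format") == "base64" or "object" in schema["type"]
    }
    rest_required = bool(unsupported_by_bulk)

    # Very wide objects would not fit into a GET query string, so the BULK API (POST) is required.
    # Assumption: a comma-joined field list approximates the connector's size estimate.
    properties_length = len(",".join(properties))
    bulk_required = properties_length + 2000 > REQUEST_SIZE_LIMITS

    if rest_required and not bulk_required:
        return "rest"
    if not rest_required:
        return "bulk"
    return None  # both constraints apply: neither API can serve the stream
```

Before this change the fall-through branch also consulted stream state (REST when state was present, BULK otherwise), which is why the two state-based tests removed further down no longer apply.
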
@@ -272,13 +272,13 @@ def execute_job(self, query: str, url: str) -> Tuple[Optional[str], Optional[str
return None, job_status
return job_full_url, job_status

def filter_null_bytes(self, s: str):
def filter_null_bytes(self, b: bytes):
"""
https://github.com/airbytehq/airbyte/issues/8300
"""
res = s.replace("\x00", "")
if len(res) < len(s):
self.logger.warning("Filter 'null' bytes from string, size reduced %d -> %d chars", len(s), len(res))
res = b.replace(b"\x00", b"")
if len(res) < len(b):
self.logger.warning("Filter 'null' bytes from string, size reduced %d -> %d chars", len(b), len(res))
return res

def download_data(self, url: str, chunk_size: float = 1024) -> os.PathLike:
@@ -292,9 +292,9 @@ def download_data(self, url: str, chunk_size: float = 1024) -> os.PathLike:
# set filepath for binary data from response
tmp_file = os.path.realpath(os.path.basename(url))
with closing(self._send_http_request("GET", f"{url}/results", stream=True)) as response:
with open(tmp_file, "w") as data_file:
with open(tmp_file, "wb") as data_file:
for chunk in response.iter_content(chunk_size=chunk_size):
data_file.writelines(self.filter_null_bytes(self.decode(chunk)))
data_file.write(self.filter_null_bytes(chunk))
# check the file exists
if os.path.isfile(tmp_file):
return tmp_file
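
The two hunks above are the substance of the "handle Japanese characters" fix: downloaded results are now filtered and written as raw bytes instead of being decoded chunk by chunk. A rough sketch of why that matters, using made-up sample data — a chunk boundary that falls inside a multibyte UTF-8 sequence breaks per-chunk decoding, while null-byte filtering and a binary write work safely on bytes.

```python
# Sketch only: shows why decoding arbitrary chunks corrupts multibyte characters,
# while byte-level null filtering plus a binary write keeps them intact.
payload = "Id,Name\n001,日本語テスト\n".encode("utf-8") + b"\x00"


def filter_null_bytes(b: bytes) -> bytes:
    # Same idea as the connector's filter_null_bytes: drop NUL bytes from the raw stream.
    return b.replace(b"\x00", b"")


# A chunk boundary landing inside the 3-byte character 日 makes per-chunk decoding fail.
chunk_1, chunk_2 = payload[:13], payload[13:]
try:
    chunk_1.decode("utf-8")
except UnicodeDecodeError as err:
    print(f"decoding a partial chunk fails: {err}")

# New approach: stay in bytes until the whole response is on disk, then read it as UTF-8.
with open("results.csv", "wb") as data_file:
    for chunk in (chunk_1, chunk_2):
        data_file.write(filter_null_bytes(chunk))

with open("results.csv", encoding="utf-8") as f:
    print(f.read())  # Id,Name / 001,日本語テスト — the Japanese text survives intact
```
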
@@ -24,12 +24,6 @@
)


@pytest.fixture(autouse=True)
def time_sleep_mock(mocker):
time_mock = mocker.patch("time.sleep", lambda x: None)
yield time_mock


def test_bulk_sync_creation_failed(stream_config, stream_api):
stream: BulkIncrementalSalesforceStream = generate_stream("Account", stream_config, stream_api)
with requests_mock.Mocker() as m:
@@ -58,26 +52,6 @@ def test_stream_contains_unsupported_properties_by_bulk(stream_config, stream_ap
assert not isinstance(stream, BulkSalesforceStream)


def test_stream_has_state_rest_api_should_be_used(stream_config, stream_api):
"""
Stream `ActiveFeatureLicenseMetric` has state, in that case REST API stream will be used for it.
"""
stream_name = "ActiveFeatureLicenseMetric"
state = {stream_name: {"SystemModstamp": "2122-08-22T05:08:29.000Z"}}
stream = generate_stream(stream_name, stream_config, stream_api, state=state)
assert not isinstance(stream, BulkSalesforceStream)


def test_stream_has_no_state_bulk_api_should_be_used(stream_config, stream_api):
"""
Stream `ActiveFeatureLicenseMetric` has no state, in that case BULK API stream will be used for it.
"""
stream_name = "ActiveFeatureLicenseMetric"
state = {"other_stream": {"SystemModstamp": "2122-08-22T05:08:29.000Z"}}
stream = generate_stream(stream_name, stream_config, stream_api, state=state)
assert isinstance(stream, BulkSalesforceStream)


@pytest.mark.parametrize("item_number", [0, 15, 2000, 2324, 3000])
def test_bulk_sync_pagination(item_number, stream_config, stream_api):
stream: BulkIncrementalSalesforceStream = generate_stream("Account", stream_config, stream_api)
@@ -239,7 +213,7 @@ def configure_request_params_mock(stream_1, stream_2):
stream_2.request_params.return_value = {"q": "query"}


def test_rate_limit_bulk(stream_config, stream_api, configured_catalog, state):
def test_rate_limit_bulk(stream_config, stream_api, bulk_catalog, state):
"""
Connector should stop the sync if one stream reached rate limit
stream_1, stream_2, stream_3, ...
@@ -282,7 +256,7 @@ def test_rate_limit_bulk(stream_config, stream_api, configured_catalog, state):

m.register_uri("POST", stream.path(), creation_responses)

result = [i for i in source.read(logger=logger, config=stream_config, catalog=configured_catalog, state=state)]
result = [i for i in source.read(logger=logger, config=stream_config, catalog=bulk_catalog, state=state)]
assert stream_1.request_params.called
assert (
not stream_2.request_params.called
@@ -295,16 +269,16 @@ def test_rate_limit_bulk(stream_config, stream_api, configured_catalog, state):
assert state_record.state.data["Account"]["LastModifiedDate"] == "2021-11-05" # state checkpoint interval is 5.


def test_rate_limit_rest(stream_config, stream_api, configured_catalog, state):
def test_rate_limit_rest(stream_config, stream_api, rest_catalog, state):
"""
Connector should stop the sync if one stream reached rate limit
stream_1, stream_2, stream_3, ...
While reading `stream_1` if 403 (Rate Limit) is received, it should finish that stream with success and stop the sync process.
Next streams should not be executed.
"""

stream_1: IncrementalSalesforceStream = generate_stream("Account", stream_config, stream_api, state=state)
stream_2: IncrementalSalesforceStream = generate_stream("Asset", stream_config, stream_api, state=state)
stream_1: IncrementalSalesforceStream = generate_stream("KnowledgeArticle", stream_config, stream_api)
stream_2: IncrementalSalesforceStream = generate_stream("AcceptedEventRelation", stream_config, stream_api)

stream_1.state_checkpoint_interval = 3
configure_request_params_mock(stream_1, stream_2)
@@ -349,7 +323,7 @@ def test_rate_limit_rest(stream_config, stream_api, configured_catalog, state):
m.register_uri("GET", stream_1.path(), json=response_1, status_code=200)
m.register_uri("GET", next_page_url, json=response_2, status_code=403)

result = [i for i in source.read(logger=logger, config=stream_config, catalog=configured_catalog, state=state)]
result = [i for i in source.read(logger=logger, config=stream_config, catalog=rest_catalog, state=state)]

assert stream_1.request_params.called
assert (
@@ -360,14 +334,12 @@ def test_rate_limit_rest(stream_config, stream_api, configured_catalog, state):
assert len(records) == 5

state_record = [item for item in result if item.type == Type.STATE][0]
assert state_record.state.data["Account"]["LastModifiedDate"] == "2021-11-17"
assert state_record.state.data["KnowledgeArticle"]["LastModifiedDate"] == "2021-11-17"


def test_pagination_rest(stream_config, stream_api):
stream_name = "ActiveFeatureLicenseMetric"
state = {stream_name: {"SystemModstamp": "2122-08-22T05:08:29.000Z"}}

stream: SalesforceStream = generate_stream(stream_name, stream_config, stream_api, state=state)
stream_name = "AcceptedEventRelation"
stream: SalesforceStream = generate_stream(stream_name, stream_config, stream_api)
stream.DEFAULT_WAIT_TIMEOUT_SECONDS = 6 # maximum wait timeout will be 6 seconds
next_page_url = "/services/data/v52.0/query/012345"
with requests_mock.Mocker() as m:
@@ -11,9 +11,22 @@
from source_salesforce.source import SourceSalesforce


@pytest.fixture(autouse=True)
def time_sleep_mock(mocker):
time_mock = mocker.patch("time.sleep", lambda x: None)
yield time_mock


@pytest.fixture(scope="module")
def bulk_catalog():
with open("unit_tests/bulk_catalog.json") as f:
data = json.loads(f.read())
return ConfiguredAirbyteCatalog.parse_obj(data)


@pytest.fixture(scope="module")
def configured_catalog():
with open("unit_tests/configured_catalog.json") as f:
def rest_catalog():
with open("unit_tests/rest_catalog.json") as f:
data = json.loads(f.read())
return ConfiguredAirbyteCatalog.parse_obj(data)

@@ -86,5 +99,5 @@ def stream_api_v2(stream_config):
return _stream_api(stream_config, describe_response_data=describe_response_data)


def generate_stream(stream_name, stream_config, stream_api, state=None):
return SourceSalesforce.generate_streams(stream_config, {stream_name: None}, stream_api, state=state)[0]
def generate_stream(stream_name, stream_config, stream_api):
return SourceSalesforce.generate_streams(stream_config, {stream_name: None}, stream_api)[0]
@@ -9,12 +9,6 @@
from source_salesforce.exceptions import TypeSalesforceException


@pytest.fixture(autouse=True)
def time_sleep_mock(mocker):
time_mock = mocker.patch("time.sleep", lambda x: None)
yield time_mock


@pytest.mark.parametrize(
"streams_criteria,predicted_filtered_streams",
[
@@ -0,0 +1,28 @@
{
"streams": [
{
"stream": {
"name": "KnowledgeArticle",
"json_schema": {},
"supported_sync_modes": ["full_refresh", "incremental"],
"source_defined_cursor": true,
"default_cursor_field": ["LastModifiedDate"],
"source_defined_primary_key": [["Id"]]
},
"sync_mode": "incremental",
"destination_sync_mode": "append"
},
{
"stream": {
"name": "AcceptedEventRelation",
"json_schema": {},
"supported_sync_modes": ["full_refresh", "incremental"],
"source_defined_cursor": true,
"default_cursor_field": ["SystemModstamp"],
"source_defined_primary_key": [["Id"]]
},
"sync_mode": "incremental",
"destination_sync_mode": "append"
}
]
}
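
The new catalog above (presumably one of the bulk_catalog.json / rest_catalog.json files referenced in the conftest.py changes) is loaded by the tests roughly like this; the snippet mirrors the fixture code shown earlier and assumes the file path used there.

```python
import json

from airbyte_cdk.models import ConfiguredAirbyteCatalog

# Mirrors the rest_catalog / bulk_catalog fixtures from conftest.py: read the JSON file
# and validate it into the ConfiguredAirbyteCatalog object that source.read() consumes.
with open("unit_tests/rest_catalog.json") as f:
    catalog = ConfiguredAirbyteCatalog.parse_obj(json.loads(f.read()))

for configured_stream in catalog.streams:
    print(configured_stream.stream.name, configured_stream.sync_mode)
```
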
@@ -11,12 +11,6 @@
from source_salesforce.streams import BulkIncrementalSalesforceStream


@pytest.fixture(autouse=True)
def time_sleep_mock(mocker):
time_mock = mocker.patch("time.sleep", lambda x: None)
yield time_mock


@pytest.mark.parametrize(
"n_records, first_size, first_peak",
(