Skip to content

Commit

Permalink
🐛 Source Amazon Ads: fix fragile polling generator (#8388)
Browse files Browse the repository at this point in the history
  • Loading branch information
monai committed Dec 28, 2021
1 parent 069587b commit 039a1aa
Show file tree
Hide file tree
Showing 2 changed files with 155 additions and 71 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
#

import json
import time
from abc import ABC, abstractmethod
from dataclasses import dataclass
from datetime import datetime, timedelta
Expand Down Expand Up @@ -57,6 +56,28 @@ class ReportInfo:
report_id: str
profile_id: int
record_type: str
status: Status
metric_objects: List[dict]


class RetryableException(Exception):
    """
    Base class for report errors that are worth retrying.

    `_try_read_records` is wrapped in a backoff decorator that retries on this
    exception type, so every subclass raised inside it triggers another attempt.
    """


class ReportGenerationFailure(RetryableException):
    """
    Raised by `_try_read_records` when a report's status is FAILURE or when
    downloading a finished report raises `requests.HTTPError`.
    """


class ReportGenerationInProgress(RetryableException):
    """
    Raised by `_try_read_records` when, after a polling pass, at least one
    report has still not reached SUCCESS status.
    """


class ReportStatusFailure(RetryableException):
    """
    Raised by `_check_status` when the report status response cannot be parsed
    (`ReportStatus.parse_raw` raised `ValueError`).
    """


class ReportInitFailure(RetryableException):
    """
    Raised by `_init_reports` when the report registration request returns a
    status code other than HTTP 202 Accepted.
    """


class TooManyRequests(Exception):
Expand All @@ -77,7 +98,7 @@ class ReportStream(BasicAmazonAdsStream, ABC):
# Async report generation time is 15 minutes according to docs:
# https://advertising.amazon.com/API/docs/en-us/get-started/developer-notes
# (Service limits section)
REPORT_WAIT_TIMEOUT = timedelta(minutes=20)
REPORT_WAIT_TIMEOUT = timedelta(minutes=30).total_seconds
# Format used to specify metric generation date over Amazon Ads API.
REPORT_DATE_FORMAT = "%Y%m%d"
cursor_field = "reportDate"
Expand Down Expand Up @@ -118,39 +139,57 @@ def read_records(
# take any action and just return.
return
report_date = stream_slice[self.cursor_field]
report_infos = self._init_and_try_read_records(report_date)

for report_info in report_infos:
for metric_object in report_info.metric_objects:
yield self._model(
profileId=report_info.profile_id,
recordType=report_info.record_type,
reportDate=report_date,
metric=metric_object,
).dict()

@backoff.on_exception(
backoff.expo,
ReportGenerationFailure,
max_tries=5,
)
def _init_and_try_read_records(self, report_date):
report_infos = self._init_reports(report_date)
logger.info(f"Waiting for {len(report_infos)} report(s) to be generated")
# According to Amazon Ads API docs metric generation takes a maximum of 15
# minutes. But in case reports won't be generated we don't want this stream to
# hang forever. Store the time point when report generation started to
# check whether it is taking too long and break the loop.
start_time_point = datetime.now()
while report_infos and datetime.now() <= start_time_point + self.REPORT_WAIT_TIMEOUT:
completed_reports = []
logger.info(f"Checking report status, {len(report_infos)} report(s) remained")
for report_info in report_infos:
report_status, download_url = self._check_status(report_info)
if report_status == Status.FAILURE:
raise Exception(f"Report for {report_info.profile_id} with {report_info.record_type} type generation failed")
elif report_status == Status.SUCCESS:
metric_objects = self._download_report(report_info, download_url)
for metric_object in metric_objects:
yield self._model(
profileId=report_info.profile_id,
recordType=report_info.record_type,
reportDate=report_date,
metric=metric_object,
).dict()
completed_reports.append(report_info)
for completed_report in completed_reports:
report_infos.remove(completed_report)
if report_infos:
logger.info(f"{len(report_infos)} report(s) remained, taking {self.CHECK_INTERVAL_SECONDS} seconds timeout")
time.sleep(self.CHECK_INTERVAL_SECONDS)
if not report_infos:
logger.info("All reports have been processed")
else:
raise Exception("Not all reports has been processed due to timeout")
self._try_read_records(report_infos)
return report_infos

# Re-run at a constant interval on any RetryableException, bounded by
# REPORT_WAIT_TIMEOUT (passed to backoff as max_time).
@backoff.on_exception(
    backoff.constant,
    RetryableException,
    max_time=REPORT_WAIT_TIMEOUT,
)
def _try_read_records(self, report_infos):
    """
    Poll every report that has not succeeded yet and download the finished ones.

    Each polled report has its ``status`` updated in place; when the status is
    SUCCESS its ``metric_objects`` is set to the downloaded report content.

    :param report_infos: reports to poll; the entries are mutated in place.
    :raises ReportGenerationFailure: a report status is FAILURE, or the
        download raised ``requests.HTTPError``.
    :raises ReportGenerationInProgress: at least one report is still not in
        SUCCESS status after this pass.
    """
    incomplete_report_infos = self._incomplete_report_infos(report_infos)
    logger.info(f"Checking report status, {len(incomplete_report_infos)} report(s) remaining")

    for report_info in incomplete_report_infos:
        status, download_url = self._check_status(report_info)
        # Record the latest status before acting on it, so it survives a raise below.
        report_info.status = status

        if status == Status.FAILURE:
            raise ReportGenerationFailure(
                f"Report for {report_info.profile_id} with {report_info.record_type} type generation failed"
            )

        if status == Status.SUCCESS:
            try:
                report_info.metric_objects = self._download_report(report_info, download_url)
            except requests.HTTPError as error:
                raise ReportGenerationFailure(error)

    # Re-evaluate after the pass: anything not SUCCESS means we have to come back.
    pending_report_status = [
        (info.profile_id, info.report_id, info.status) for info in self._incomplete_report_infos(report_infos)
    ]
    if pending_report_status:
        raise ReportGenerationInProgress(f"Report generation in progress: {repr(pending_report_status)}")

def _incomplete_report_infos(self, report_infos):
    """
    Return a new list with the reports whose status is not SUCCESS, keeping
    their original order.
    """
    return list(filter(lambda info: info.status != Status.SUCCESS, report_infos))

def _generate_model(self):
"""
Expand Down Expand Up @@ -191,7 +230,12 @@ def _check_status(self, report_info: ReportInfo) -> Tuple[Status, str]:
"""
check_endpoint = f"/v2/reports/{report_info.report_id}"
resp = self._send_http_request(urljoin(self._url, check_endpoint), report_info.profile_id)
resp = ReportStatus.parse_raw(resp.text)

try:
resp = ReportStatus.parse_raw(resp.text)
except ValueError as error:
raise ReportStatusFailure(error)

return resp.status, resp.location

@backoff.on_exception(
Expand Down Expand Up @@ -265,6 +309,11 @@ def _get_init_report_body(self, report_date: str, record_type: str, profile) ->
Override to return dict representing body of POST request for initiating report creation.
"""

@backoff.on_exception(
backoff.expo,
ReportInitFailure,
max_tries=5,
)
def _init_reports(self, report_date: str) -> List[ReportInfo]:
"""
Send report generation requests for all profiles and for all record types for specific day.
Expand Down Expand Up @@ -292,8 +341,8 @@ def _init_reports(self, report_date: str) -> List[ReportInfo]:
report_init_body,
)
if response.status_code != HTTPStatus.ACCEPTED:
raise Exception(
f"Unexpected error when registering {record_type}, {self.__class__.__name__} for {profile.profileId} profile: {response.text}"
raise ReportInitFailure(
f"Unexpected HTTP status code {response.status_code} when registering {record_type}, {type(self).__name__} for {profile.profileId} profile: {response.text}"
)

response = ReportInitResponse.parse_raw(response.text)
Expand All @@ -302,6 +351,8 @@ def _init_reports(self, report_date: str) -> List[ReportInfo]:
report_id=response.reportId,
record_type=record_type,
profile_id=profile.profileId,
status=Status.IN_PROGRESS,
metric_objects=[],
)
)
logger.info("Initiated successfully")
Expand All @@ -324,10 +375,16 @@ def _calc_report_generation_date(report_date: str, profile) -> str:
profile_time = report_date.astimezone(profile_tz)
return profile_time.strftime(ReportStream.REPORT_DATE_FORMAT)

# Retry with exponential backoff on requests.HTTPError, up to 5 tries.
@backoff.on_exception(
    backoff.expo,
    requests.HTTPError,
    max_tries=5,
)
def _download_report(self, report_info: ReportInfo, url: str) -> List[dict]:
    """
    Fetch a generated report from ``url`` and return its parsed JSON content.

    The response body is decompressed, decoded to text and parsed as JSON.

    :param report_info: report being downloaded; its ``profile_id`` is passed to the HTTP request.
    :param url: download location of the report.
    :raises requests.HTTPError: propagated once the retries of the decorator are used up.
    """
    response = self._send_http_request(url, report_info.profile_id)
    # NOTE(review): presumably a requests.Response, so an error status becomes
    # the requests.HTTPError the decorator retries on - confirm.
    response.raise_for_status()
    return json.loads(decompress(response.content).decode("utf"))
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
SponsoredDisplayReportStream,
SponsoredProductsReportStream,
)
from source_amazon_ads.streams.report_streams.report_streams import TooManyRequests
from source_amazon_ads.streams.report_streams.report_streams import ReportGenerationFailure, ReportGenerationInProgress, TooManyRequests

"""
METRIC_RESPONSE is gzip compressed binary representing this string:
Expand Down Expand Up @@ -141,7 +141,7 @@ def test_products_report_stream(test_config):
profiles = make_profiles(profile_type="vendor")

stream = SponsoredProductsReportStream(config, profiles, authenticator=mock.MagicMock())
stream_slice = {"reportDate": "20210725"}
stream_slice = {"reportDate": "20210725", "retry_count": 3}
metrics = [m for m in stream.read_records(SyncMode.incremental, stream_slice=stream_slice)]
assert len(metrics) == METRICS_COUNT * len(stream.metrics_map)

Expand Down Expand Up @@ -180,23 +180,6 @@ def test_brands_video_report_stream(test_config):
assert len(metrics) == METRICS_COUNT * len(stream.metrics_map)


@responses.activate
def test_display_report_stream_report_generation_failure(test_config):
    """
    Reading records must raise when the report status endpoint answers FAILURE.
    """
    # Same canned status payload as the happy path, with the status flipped to FAILURE.
    failed_status = REPORT_STATUS_RESPONSE.replace("SUCCESS", "FAILURE")
    setup_responses(
        init_response=REPORT_INIT_RESPONSE,
        status_response=failed_status,
        metric_response=METRIC_RESPONSE,
    )

    stream = SponsoredDisplayReportStream(
        AmazonAdsConfig(**test_config),
        make_profiles(),
        authenticator=mock.MagicMock(),
    )

    with pytest.raises(Exception):
        # Exhaust the generator so the failure actually surfaces.
        list(stream.read_records(SyncMode.incremental, stream_slice={"reportDate": "20210725"}))


@responses.activate
def test_display_report_stream_init_failure(mocker, test_config):
config = AmazonAdsConfig(**test_config)
Expand Down Expand Up @@ -239,38 +222,82 @@ def test_display_report_stream_init_too_many_requests(mocker, test_config):
assert len(responses.calls) == 5


@pytest.mark.parametrize(
("modifiers", "expected"),
[
(
[
(lambda x: x <= 5, "SUCCESS", None),
],
5,
),
(
[
(lambda x: x > 5, "SUCCESS", None),
],
10,
),
(
[
(lambda x: x > 5, None, "2021-01-02 03:34:05"),
],
ReportGenerationInProgress,
),
(
[
(lambda x: x >= 1 and x <= 5, "FAILURE", None),
(lambda x: x >= 6 and x <= 10, None, "2021-01-02 03:23:05"),
(lambda x: x >= 11, "SUCCESS", "2021-01-02 03:24:06"),
],
15,
),
(
[
(lambda x: True, "FAILURE", None),
(lambda x: x >= 10, None, "2021-01-02 03:34:05"),
(lambda x: x >= 15, None, "2021-01-02 04:04:05"),
(lambda x: x >= 20, None, "2021-01-02 04:34:05"),
(lambda x: x >= 25, None, "2021-01-02 05:04:05"),
(lambda x: x >= 30, None, "2021-01-02 05:34:05"),
],
ReportGenerationFailure,
),
],
)
@responses.activate
def test_display_report_stream_timeout(mocker, test_config):
time_mock = mock.MagicMock()
mocker.patch("time.sleep", time_mock)
def test_display_report_stream_backoff(mocker, test_config, modifiers, expected):
setup_responses(init_response=REPORT_INIT_RESPONSE, metric_response=METRIC_RESPONSE)

with freeze_time("2021-07-30 04:26:08") as frozen_time:
success_cnt = 2
with freeze_time("2021-01-02 03:04:05") as frozen_time:

class StatusCallback:
count: int = 0

def __call__(self, request):
self.count += 1
response = REPORT_STATUS_RESPONSE
if self.count > success_cnt:
response = REPORT_STATUS_RESPONSE.replace("SUCCESS", "IN_PROGRESS")
if self.count > success_cnt + 1:
frozen_time.move_to("2021-07-30 06:26:08")
response = REPORT_STATUS_RESPONSE.replace("SUCCESS", "IN_PROGRESS")

for index, status, time in modifiers:
if index(self.count):
if status:
response = response.replace("IN_PROGRESS", status)
if time:
frozen_time.move_to(time)
return (200, {}, response)

responses.add_callback(
responses.GET, re.compile(r"https://advertising-api.amazon.com/v2/reports/[^/]+$"), callback=StatusCallback()
)
callback = StatusCallback()
responses.add_callback(responses.GET, re.compile(r"https://advertising-api.amazon.com/v2/reports/[^/]+$"), callback=callback)
config = AmazonAdsConfig(**test_config)
profiles = make_profiles()
stream = SponsoredDisplayReportStream(config, profiles, authenticator=mock.MagicMock())
stream_slice = {"reportDate": "20210725"}

with pytest.raises(Exception):
_ = [m for m in stream.read_records(SyncMode.incremental, stream_slice=stream_slice)]
time_mock.assert_called_with(30)
if isinstance(expected, int):
list(stream.read_records(SyncMode.incremental, stream_slice=stream_slice))
assert callback.count == expected
elif issubclass(expected, Exception):
with pytest.raises(expected):
list(stream.read_records(SyncMode.incremental, stream_slice=stream_slice))


@freeze_time("2021-07-30 04:26:08")
Expand Down

0 comments on commit 039a1aa

Please sign in to comment.