This repository has been archived by the owner on Aug 4, 2023. It is now read-only.

Commit

Update Flickr large batch handling (#1047)
* Prevent serialization errors when query params include datetimes

* Refactor TimeDelineated base class to allow calling ingest_records on a single ts pair

* Update Flickr to split large batches by license and ingest as many records as possible
stacimc committed Mar 27, 2023
1 parent 5ff4439 commit 7d0ce7f
Showing 5 changed files with 185 additions and 65 deletions.
100 changes: 81 additions & 19 deletions openverse_catalog/dags/providers/provider_api_scripts/flickr.py
@@ -15,7 +15,6 @@
from datetime import datetime, timedelta

import lxml.html as html
from airflow.exceptions import AirflowException
from airflow.models import Variable
from common import constants
from common.licenses import get_license_info
@@ -71,24 +70,77 @@ def __init__(self, *args, **kwargs):
# are built multiple times. Fetch keys and build parameters here so it
# is done only once.
self.api_key = Variable.get("API_KEY_FLICKR")
self.license_param = ",".join(LICENSE_INFO.keys())
self.default_license_param = ",".join(LICENSE_INFO.keys())

# Keeps track of number of requests made, so we can see how close we are
# to hitting rate limits.
self.requests_count = 0

# Keep track of batches containing more than the max_unique_records
self.large_batches = []
# When we encounter a batch containing more than the max_unique_records, we can
# either return early without attempting to ingest records, or we can ingest up
# to the max_unique_records count. This flag is used to control this behavior
# during different stages of the ingestion process.
self.process_large_batch = False

def ingest_records(self, **kwargs):
"""
Ingest records, handling large batches.
The Flickr API returns a maximum of 4,000 unique records per query; after that,
it returns only duplicates. Therefore in order to ingest as many unique records
as possible, we must attempt to break ingestion into queries that contain fewer
than the max_unique_records each.
First, we use the TimeDelineatedProviderDataIngester to break queries down into
smaller intervals of time throughout the ingestion day. However, this only has
granularity down to about 5 minute intervals. If a 5-minute interval contains
more than the max unique records, we then try splitting the interval up by
license, making a separate query for each license type.
If a 5-minute interval contains more than max_unique_records for a single
license type, we accept that we cannot produce a query small enough to ingest
all the unique records over this time period. We will process up to the
max_unique_records count and then move on to the next batch.
"""
# Perform ingestion as normal, splitting requests into time-slices of at most
# 5 minutes. When a batch is encountered which contains more than
# max_unique_records, it is skipped and added to the `large_batches` list for
# later processing.
super().ingest_records(**kwargs)
logger.info("Completed initial ingestion.")

# If we encounter a large batch at this stage of ingestion, we cannot break
# the batch down any further so we want to attempt to process as much as
# possible.
self.process_large_batch = True

for start_ts, end_ts in self.large_batches:
# For each large batch, ingest records for that interval one license
# type at a time.
for license in LICENSE_INFO.keys():
super().ingest_records_for_timestamp_pair(
start_ts=start_ts, end_ts=end_ts, license=license
)
logger.info("Completed large batch processing by license type.")

def get_next_query_params(self, prev_query_params, **kwargs):
if not prev_query_params:
# Initial request, return default params
start_timestamp = kwargs.get("start_ts")
end_timestamp = kwargs.get("end_ts")

# license will be available in the params if we're dealing
# with a large batch. If not, fall back to all licenses
license = kwargs.get("license", self.default_license_param)

return {
"min_upload_date": start_timestamp,
"max_upload_date": end_timestamp,
"page": 0,
"api_key": self.api_key,
"license": self.license_param,
"license": license,
"per_page": self.batch_limit,
"method": "flickr.photos.search",
"media": "photos",
@@ -129,6 +181,32 @@ def get_batch_data(self, response_json):

if response_json is None or response_json.get("stat") != "ok":
return None

# Detect if this batch has more than the max_unique_records.
detected_count = self.get_record_count_from_response(response_json)
if detected_count > self.max_unique_records:
logger.error(
f"{detected_count} records retrieved, but there is a"
f" limit of {self.max_unique_records}."
)

if not self.process_large_batch:
# We don't want to attempt to process this batch. Return an empty
# batch now, and we will try again later, splitting the batch up by
# license type.
self.large_batches.append(self.current_timestamp_pair)
return None

# If we do want to process large batches, we should only ingest up to
# the `max_unique_records` and then return early, as all records
# retrieved afterward will be duplicates.
page_number = response_json.get("photos", {}).get("page", 1)
if self.batch_limit * page_number > self.max_unique_records:
# We have already ingested up to `max_unique_records` for this
# batch, so we should only be getting duplicates after this.
return None

# Return data for this batch
return response_json.get("photos", {}).get("photo")

def get_record_count_from_response(self, response_json) -> int:
@@ -258,22 +336,6 @@ def _get_category(image_data):
return ImageCategory.PHOTOGRAPH
return None

def get_should_continue(self, response_json):
# Call the parent method in order to update the fetched_count
should_continue = super().get_should_continue(response_json)

# Return early if more than the maximum unique records have been ingested.
# This could happen if we did not break the ingestion down into
# small enough divisions.
if self.fetched_count > self.max_unique_records:
raise AirflowException(
f"{self.fetched_count} records retrieved, but there is a"
f" limit of {self.max_records}. Consider increasing the"
" number of divisions."
)

return should_continue


def main(date):
ingester = FlickrDataIngester(date=date)
@@ -37,7 +37,7 @@ class IngestionError(Exception):
def __init__(self, error, traceback, query_params):
self.error = error
self.traceback = traceback
self.query_params = json.dumps(query_params)
self.query_params = json.dumps(query_params, default=str)

def __str__(self):
# Append query_param info to error message
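
For context on the `default=str` change above, a small sketch (not part of the diff) of the serialization error it prevents when query params contain datetimes:

```python
import json
from datetime import datetime, timezone

query_params = {"min_upload_date": datetime(2023, 3, 26, tzinfo=timezone.utc), "page": 0}

try:
    json.dumps(query_params)
except TypeError as err:
    # Without a fallback: "Object of type datetime is not JSON serializable"
    print(err)

# default=str falls back to str() for values json can't handle natively, so the
# datetime is rendered as an ISO-style string and the error message can still
# include the offending query params.
print(json.dumps(query_params, default=str))
```
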
@@ -2,7 +2,6 @@
from abc import abstractmethod
from datetime import datetime, timedelta, timezone

from airflow.exceptions import AirflowException
from providers.provider_api_scripts.provider_data_ingester import ProviderDataIngester


@@ -60,6 +59,12 @@ def __init__(self, *args, **kwargs):
# specifically in this iteration.
self.fetched_count = 0

# Keep track of our ts pairs
self.timestamp_pairs = []

# Keep track of the current ts pair
self.current_timestamp_pair = ()

@staticmethod
def format_ts(timestamp):
return timestamp.isoformat().replace("+00:00", "Z")
@@ -153,6 +158,7 @@ def _get_timestamp_pairs(self, **kwargs):
# contain data. Hours that contain more data get divided into a larger number of
# portions.
hour_slices = self._get_timestamp_query_params_list(start_ts, end_ts, 24)

for (start_hour, end_hour) in hour_slices:
# Get the number of records in this hour interval
record_count = self._get_record_count(start_hour, end_hour, **kwargs)
@@ -190,18 +196,27 @@ def _get_timestamp_pairs(self, **kwargs):
return pairs_list
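
As a rough standalone sketch of the interval splitting described above (not the actual `_get_timestamp_query_params_list` implementation), a day can be cut into equal slices and busy slices subdivided further:

```python
from datetime import datetime, timedelta, timezone

def split_interval(start, end, portions):
    # Divide [start, end) into `portions` equal, contiguous sub-intervals.
    step = (end - start) / portions
    return [(start + i * step, start + (i + 1) * step) for i in range(portions)]

day_start = datetime(2023, 3, 26, tzinfo=timezone.utc)
hour_slices = split_interval(day_start, day_start + timedelta(days=1), 24)

# An hour that reports a large record count can be subdivided again, down to
# roughly five-minute windows (12 portions per hour).
five_minute_slices = split_interval(*hour_slices[0], 12)
```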

def ingest_records(self, **kwargs) -> None:
timestamp_pairs = self._get_timestamp_pairs(**kwargs)
if timestamp_pairs:
logger.info(f"{len(timestamp_pairs)} timestamp pairs generated.")
self.timestamp_pairs = self._get_timestamp_pairs(**kwargs)
if self.timestamp_pairs:
logger.info(f"{len(self.timestamp_pairs)} timestamp pairs generated.")

# Run ingestion for each timestamp pair
for start_ts, end_ts in timestamp_pairs:
# Reset counts before we start
self.new_iteration = True
self.fetched_count = 0
for start_ts, end_ts in self.timestamp_pairs:
self.ingest_records_for_timestamp_pair(start_ts, end_ts, **kwargs)

def ingest_records_for_timestamp_pair(
self, start_ts: datetime, end_ts: datetime, **kwargs
):
# Update `current_timestamp_pair` to keep track of what we are processing.
self.current_timestamp_pair = (start_ts, end_ts)

# Reset counts
self.new_iteration = True
self.fetched_count = 0

logger.info(f"Ingesting data for start: {start_ts}, end: {end_ts}")
super().ingest_records(start_ts=start_ts, end_ts=end_ts, **kwargs)
# Run ingestion for the given parameters
logger.info(f"Ingesting data for start: {start_ts}, end: {end_ts}")
super().ingest_records(start_ts=start_ts, end_ts=end_ts, **kwargs)

def get_should_continue(self, response_json) -> bool:
"""
@@ -235,9 +250,9 @@ def get_should_continue(self, response_json) -> bool:
" been fetched. Consider reducing the ingestion interval."
)
if self.should_raise_error:
raise AirflowException(error_message)
raise Exception(error_message)
else:
logger.error(error_message)
logger.info(error_message)

# If `should_raise_error` was enabled, the error is raised and ingestion
# halted. If not, we want to log but continue ingesting.
104 changes: 74 additions & 30 deletions tests/dags/providers/provider_api_scripts/test_flickr.py
@@ -1,9 +1,9 @@
import json
import os
from datetime import datetime
from unittest import mock

import pytest
from airflow.exceptions import AirflowException
from common.licenses import LicenseInfo
from providers.provider_api_scripts.flickr import FlickrDataIngester

@@ -71,35 +71,6 @@ def test_get_record_count_from_response():
assert count == 30


@pytest.mark.parametrize(
"fetched_count, super_should_continue, expected_should_continue",
(
(1000, True, True),
# Return False if super().get_should_continue() is False
(1000, False, False),
(4000, True, True),
# Raise exception if fetched_count exceeds max_unique_records
pytest.param(
4001,
True,
False,
marks=pytest.mark.raises(exception=AirflowException),
),
),
)
def test_get_should_continue(
fetched_count, super_should_continue, expected_should_continue
):
with mock.patch(
"providers.provider_api_scripts.time_delineated_provider_data_ingester.TimeDelineatedProviderDataIngester.get_should_continue",
return_value=super_should_continue,
):
ingester = FlickrDataIngester(date=FROZEN_DATE)
ingester.fetched_count = fetched_count

assert ingester.get_should_continue({}) == expected_should_continue


@pytest.mark.parametrize(
"response_json, expected_response",
[
@@ -131,6 +102,42 @@ def test_get_batch_data(response_json, expected_response):
assert actual_response == expected_response


@pytest.mark.parametrize(
"process_large_batch, page_number, expected_response, expected_large_batches",
[
# process_large_batch is False: always return None and add batch to the list
(False, 1, None, 1),
(False, 20, None, 1),
# process_large_batch is True: never add batch to the list
# Fewer than max_unique_records have been processed
(True, 1, _get_resource_json("flickr_example_photo_list.json"), 0),
(True, 10, _get_resource_json("flickr_example_photo_list.json"), 0),
# More than max_unique_records have been processed
(True, 11, None, 0),
(True, 20, None, 0),
],
)
def test_get_batch_data_when_detected_count_exceeds_max_unique_records(
process_large_batch, page_number, expected_response, expected_large_batches
):
# Hard code the batch_limit and max_unique_records for the test
ingester = FlickrDataIngester(date=FROZEN_DATE)
# This means you should be able to get 100/10 = 10 pages of results before
# you exceed the max_unique_records.
ingester.batch_limit = 10
ingester.max_unique_records = 100
ingester.process_large_batch = process_large_batch

response_json = _get_resource_json("flickr_example_pretty.json")
response_json["photos"]["page"] = page_number
response_json["photos"]["total"] = 200 # More than max unique records

actual_response = ingester.get_batch_data(response_json)
assert actual_response == expected_response

assert len(ingester.large_batches) == expected_large_batches


def test_get_record_data():
image_data = _get_resource_json("image_data_complete_example.json")
actual_data = flickr.get_record_data(image_data)
@@ -342,3 +349,40 @@ def test_get_record_data_with_sub_provider():
"category": None,
}
assert actual_data == expected_data


def test_ingest_records():
# Test a 'normal' run where no large batches are detected during ingestion.
with (
mock.patch(
"providers.provider_api_scripts.time_delineated_provider_data_ingester.TimeDelineatedProviderDataIngester.ingest_records"
),
mock.patch(
"providers.provider_api_scripts.time_delineated_provider_data_ingester.TimeDelineatedProviderDataIngester.ingest_records_for_timestamp_pair"
) as ingest_for_pair_mock,
):
flickr.ingest_records()
# No additional calls to ingest_records_for_timestamp_pair were made
assert not ingest_for_pair_mock.called


def test_ingest_records_when_large_batches_detected():
ingester = FlickrDataIngester(date=FROZEN_DATE)
with (
mock.patch(
"providers.provider_api_scripts.time_delineated_provider_data_ingester.TimeDelineatedProviderDataIngester.ingest_records"
),
mock.patch(
"providers.provider_api_scripts.time_delineated_provider_data_ingester.TimeDelineatedProviderDataIngester.ingest_records_for_timestamp_pair"
) as ingest_for_pair_mock,
):
# Set large_batches to include one timestamp pair
mock_start = datetime(2020, 1, 1, 1, 0, 0)
mock_end = datetime(2020, 1, 1, 2, 0, 0)
ingester.large_batches = [
(mock_start, mock_end),
]

ingester.ingest_records()
# An additional call is made to ingest_records_for_timestamp_pair for each license type
assert ingest_for_pair_mock.call_count == 8
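
A quick sanity check on where the expected call count comes from, assuming `LICENSE_INFO` has eight keys (which is what the assertion implies; the specific IDs below are illustrative):

```python
# One ingest_records_for_timestamp_pair call per license key, per large batch.
assumed_license_keys = ["1", "2", "3", "4", "5", "6", "9", "10"]  # illustrative IDs
large_batches = 1
assert len(assumed_license_keys) * large_batches == 8
```
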
@@ -3,7 +3,6 @@
from unittest.mock import MagicMock, call, patch

import pytest
from airflow.exceptions import AirflowException

from tests.dags.providers.provider_api_scripts.resources.provider_data_ingester.mock_provider_data_ingester import (
MAX_RECORDS,
@@ -205,7 +204,7 @@ def test_ts_pairs_and_kwargs_are_available_in_get_next_query_params():


def test_ingest_records_raises_error_if_the_total_count_has_been_exceeded():
# Test that `ingest_records` raises an AirflowException if the external
# Test that `ingest_records` raises an Exception if the external
# API continues returning data in excess of the stated `resultCount`
# (https://github.com/WordPress/openverse-catalog/pull/934)
with (
@@ -238,7 +237,7 @@ def test_ingest_records_raises_error_if_the_total_count_has_been_exceeded():
)
# Assert that attempting to ingest records raises an exception when
# `should_raise_error` is enabled
with (pytest.raises(AirflowException, match=expected_error_string)):
with (pytest.raises(Exception, match=expected_error_string)):
ingester.ingest_records()

# get_mock should have been called 4 times, twice for each batch (once in `get_batch`
