diff --git a/openverse_catalog/dags/providers/provider_api_scripts/flickr.py b/openverse_catalog/dags/providers/provider_api_scripts/flickr.py
index 70cc4f5a1..f3f357ca6 100644
--- a/openverse_catalog/dags/providers/provider_api_scripts/flickr.py
+++ b/openverse_catalog/dags/providers/provider_api_scripts/flickr.py
@@ -15,7 +15,6 @@
 from datetime import datetime, timedelta
 
 import lxml.html as html
-from airflow.exceptions import AirflowException
 from airflow.models import Variable
 from common import constants
 from common.licenses import get_license_info
@@ -71,24 +70,77 @@ def __init__(self, *args, **kwargs):
         # are built multiple times. Fetch keys and build parameters here so it
         # is done only once.
         self.api_key = Variable.get("API_KEY_FLICKR")
-        self.license_param = ",".join(LICENSE_INFO.keys())
+        self.default_license_param = ",".join(LICENSE_INFO.keys())
 
         # Keeps track of number of requests made, so we can see how close we are
         # to hitting rate limits.
         self.requests_count = 0
 
+        # Keep track of batches containing more than the max_unique_records
+        self.large_batches = []
+
+        # When we encounter a batch containing more than the max_unique_records, we can
+        # either return early without attempting to ingest records, or we can ingest up
+        # to the max_unique_records count. This flag is used to control this behavior
+        # during different stages of the ingestion process.
+        self.process_large_batch = False
+
+    def ingest_records(self, **kwargs):
+        """
+        Ingest records, handling large batches.
+
+        The Flickr API returns a maximum of 4,000 unique records per query; after that,
+        it returns only duplicates. Therefore in order to ingest as many unique records
+        as possible, we must attempt to break ingestion into queries that contain fewer
+        than the max_unique_records each.
+
+        First, we use the TimeDelineatedProviderDataIngester to break queries down into
+        smaller intervals of time throughout the ingestion day. However, this only has
+        granularity down to about 5 minute intervals. If a 5-minute interval contains
+        more than the max unique records, we then try splitting the interval up by
+        license, making a separate query for each license type.
+
+        If a 5-minute interval contains more than max_unique_records for a single
+        license type, we accept that we cannot produce a query small enough to ingest
+        all the unique records over this time period. We will process up to the
+        max_unique_records count and then move on to the next batch.
+        """
+        # Perform ingestion as normal, splitting requests into time-slices of at most
+        # 5 minutes. When a batch is encountered which contains more than
+        # max_unique_records, it is skipped and added to the `large_batches` list for
+        # later processing.
+        super().ingest_records(**kwargs)
+        logger.info("Completed initial ingestion.")
+
+        # If we encounter a large batch at this stage of ingestion, we cannot break
+        # the batch down any further so we want to attempt to process as much as
+        # possible.
+        self.process_large_batch = True
+
+        for start_ts, end_ts in self.large_batches:
+            # For each large batch, ingest records for that interval one license
+            # type at a time.
+            for license in LICENSE_INFO.keys():
+                super().ingest_records_for_timestamp_pair(
+                    start_ts=start_ts, end_ts=end_ts, license=license
+                )
+        logger.info("Completed large batch processing by license type.")
+
     def get_next_query_params(self, prev_query_params, **kwargs):
         if not prev_query_params:
             # Initial request, return default params
             start_timestamp = kwargs.get("start_ts")
             end_timestamp = kwargs.get("end_ts")
 
+            # license will be available in the params if we're dealing
+            # with a large batch. If not, fall back to all licenses
+            license = kwargs.get("license", self.default_license_param)
+
             return {
                 "min_upload_date": start_timestamp,
                 "max_upload_date": end_timestamp,
                 "page": 0,
                 "api_key": self.api_key,
-                "license": self.license_param,
+                "license": license,
                 "per_page": self.batch_limit,
                 "method": "flickr.photos.search",
                 "media": "photos",
@@ -129,6 +181,32 @@ def get_batch_data(self, response_json):
         if response_json is None or response_json.get("stat") != "ok":
             return None
+
+        # Detect if this batch has more than the max_unique_records.
+        detected_count = self.get_record_count_from_response(response_json)
+        if detected_count > self.max_unique_records:
+            logger.error(
+                f"{detected_count} records retrieved, but there is a"
+                f" limit of {self.max_unique_records}."
+            )
+
+            if not self.process_large_batch:
+                # We don't want to attempt to process this batch. Return an empty
+                # batch now, and we will try again later, splitting the batch up by
+                # license type.
+                self.large_batches.append(self.current_timestamp_pair)
+                return None
+
+            # If we do want to process large batches, we should only ingest up to
+            # the `max_unique_records` and then return early, as all records
+            # retrieved afterward will be duplicates.
+            page_number = response_json.get("photos", {}).get("page", 1)
+            if self.batch_limit * page_number > self.max_unique_records:
+                # We have already ingested up to `max_unique_records` for this
+                # batch, so we should only be getting duplicates after this.
+                return None
+
+        # Return data for this batch
         return response_json.get("photos", {}).get("photo")
 
     def get_record_count_from_response(self, response_json) -> int:
@@ -258,22 +336,6 @@ def _get_category(image_data):
             return ImageCategory.PHOTOGRAPH
         return None
 
-    def get_should_continue(self, response_json):
-        # Call the parent method in order to update the fetched_count
-        should_continue = super().get_should_continue(response_json)
-
-        # Return early if more than the maximum unique records have been ingested.
-        # This could happen if we did not break the ingestion down into
-        # small enough divisions.
-        if self.fetched_count > self.max_unique_records:
-            raise AirflowException(
-                f"{self.fetched_count} records retrieved, but there is a"
-                f" limit of {self.max_records}. Consider increasing the"
-                " number of divisions."
-            )
-
-        return should_continue
-
 
 def main(date):
     ingester = FlickrDataIngester(date=date)
diff --git a/openverse_catalog/dags/providers/provider_api_scripts/provider_data_ingester.py b/openverse_catalog/dags/providers/provider_api_scripts/provider_data_ingester.py
index 5339e6580..19b83285b 100644
--- a/openverse_catalog/dags/providers/provider_api_scripts/provider_data_ingester.py
+++ b/openverse_catalog/dags/providers/provider_api_scripts/provider_data_ingester.py
@@ -37,7 +37,7 @@ class IngestionError(Exception):
     def __init__(self, error, traceback, query_params):
         self.error = error
         self.traceback = traceback
-        self.query_params = json.dumps(query_params)
+        self.query_params = json.dumps(query_params, default=str)
 
     def __str__(self):
         # Append query_param info to error message
diff --git a/openverse_catalog/dags/providers/provider_api_scripts/time_delineated_provider_data_ingester.py b/openverse_catalog/dags/providers/provider_api_scripts/time_delineated_provider_data_ingester.py
index 9e54ee503..6f379e624 100644
--- a/openverse_catalog/dags/providers/provider_api_scripts/time_delineated_provider_data_ingester.py
+++ b/openverse_catalog/dags/providers/provider_api_scripts/time_delineated_provider_data_ingester.py
@@ -2,7 +2,6 @@
 from abc import abstractmethod
 from datetime import datetime, timedelta, timezone
 
-from airflow.exceptions import AirflowException
 from providers.provider_api_scripts.provider_data_ingester import ProviderDataIngester
 
 
@@ -60,6 +59,12 @@ def __init__(self, *args, **kwargs):
         # specifically in this iteration.
         self.fetched_count = 0
 
+        # Keep track of our ts pairs
+        self.timestamp_pairs = []
+
+        # Keep track of the current ts pair
+        self.current_timestamp_pair = ()
+
     @staticmethod
     def format_ts(timestamp):
         return timestamp.isoformat().replace("+00:00", "Z")
@@ -153,6 +158,7 @@ def _get_timestamp_pairs(self, **kwargs):
         # contain data. Hours that contain more data get divided into a larger number of
         # portions.
         hour_slices = self._get_timestamp_query_params_list(start_ts, end_ts, 24)
+
        for (start_hour, end_hour) in hour_slices:
             # Get the number of records in this hour interval
             record_count = self._get_record_count(start_hour, end_hour, **kwargs)
@@ -190,18 +196,27 @@ def _get_timestamp_pairs(self, **kwargs):
         return pairs_list
 
     def ingest_records(self, **kwargs) -> None:
-        timestamp_pairs = self._get_timestamp_pairs(**kwargs)
-        if timestamp_pairs:
-            logger.info(f"{len(timestamp_pairs)} timestamp pairs generated.")
+        self.timestamp_pairs = self._get_timestamp_pairs(**kwargs)
+        if self.timestamp_pairs:
+            logger.info(f"{len(self.timestamp_pairs)} timestamp pairs generated.")
 
         # Run ingestion for each timestamp pair
-        for start_ts, end_ts in timestamp_pairs:
-            # Reset counts before we start
-            self.new_iteration = True
-            self.fetched_count = 0
+        for start_ts, end_ts in self.timestamp_pairs:
+            self.ingest_records_for_timestamp_pair(start_ts, end_ts, **kwargs)
+
+    def ingest_records_for_timestamp_pair(
+        self, start_ts: datetime, end_ts: datetime, **kwargs
+    ):
+        # Update `current_timestamp_pair` to keep track of what we are processing.
+        self.current_timestamp_pair = (start_ts, end_ts)
+
+        # Reset counts
+        self.new_iteration = True
+        self.fetched_count = 0
 
-            logger.info(f"Ingesting data for start: {start_ts}, end: {end_ts}")
-            super().ingest_records(start_ts=start_ts, end_ts=end_ts, **kwargs)
+        # Run ingestion for the given parameters
+        logger.info(f"Ingesting data for start: {start_ts}, end: {end_ts}")
+        super().ingest_records(start_ts=start_ts, end_ts=end_ts, **kwargs)
 
     def get_should_continue(self, response_json) -> bool:
         """
@@ -235,9 +250,9 @@ def get_should_continue(self, response_json) -> bool:
                 " been fetched. Consider reducing the ingestion interval."
             )
             if self.should_raise_error:
-                raise AirflowException(error_message)
+                raise Exception(error_message)
             else:
-                logger.error(error_message)
+                logger.info(error_message)
 
         # If `should_raise_error` was enabled, the error is raised and ingestion
         # halted. If not, we want to log but continue ingesting.
diff --git a/tests/dags/providers/provider_api_scripts/test_flickr.py b/tests/dags/providers/provider_api_scripts/test_flickr.py
index 7167dec32..ab85e2c3a 100644
--- a/tests/dags/providers/provider_api_scripts/test_flickr.py
+++ b/tests/dags/providers/provider_api_scripts/test_flickr.py
@@ -1,9 +1,9 @@
 import json
 import os
+from datetime import datetime
 from unittest import mock
 
 import pytest
-from airflow.exceptions import AirflowException
 from common.licenses import LicenseInfo
 from providers.provider_api_scripts.flickr import FlickrDataIngester
 
@@ -71,35 +71,6 @@ def test_get_record_count_from_response():
     assert count == 30
 
 
-@pytest.mark.parametrize(
-    "fetched_count, super_should_continue, expected_should_continue",
-    (
-        (1000, True, True),
-        # Return False if super().get_should_continue() is False
-        (1000, False, False),
-        (4000, True, True),
-        # Raise exception if fetched_count exceeds max_unique_records
-        pytest.param(
-            4001,
-            True,
-            False,
-            marks=pytest.mark.raises(exception=AirflowException),
-        ),
-    ),
-)
-def test_get_should_continue(
-    fetched_count, super_should_continue, expected_should_continue
-):
-    with mock.patch(
-        "providers.provider_api_scripts.time_delineated_provider_data_ingester.TimeDelineatedProviderDataIngester.get_should_continue",
-        return_value=super_should_continue,
-    ):
-        ingester = FlickrDataIngester(date=FROZEN_DATE)
-        ingester.fetched_count = fetched_count
-
-        assert ingester.get_should_continue({}) == expected_should_continue
-
-
 @pytest.mark.parametrize(
     "response_json, expected_response",
     [
@@ -131,6 +102,42 @@ def test_get_batch_data(response_json, expected_response):
     assert actual_response == expected_response
 
 
+@pytest.mark.parametrize(
+    "process_large_batch, page_number, expected_response, expected_large_batches",
+    [
+        # process_large_batch is False: always return None and add batch to the list
+        (False, 1, None, 1),
+        (False, 20, None, 1),
+        # process_large_batch is True: never add batch to the list
+        # Fewer than max_unique_records have been processed
+        (True, 1, _get_resource_json("flickr_example_photo_list.json"), 0),
+        (True, 10, _get_resource_json("flickr_example_photo_list.json"), 0),
+        # More than max_unique_records have been processed
+        (True, 11, None, 0),
+        (True, 20, None, 0),
+    ],
+)
+def test_get_batch_data_when_detected_count_exceeds_max_unique_records(
+    process_large_batch, page_number, expected_response, expected_large_batches
+):
+    # Hard code the batch_limit and max_unique_records for the test
+    ingester = FlickrDataIngester(date=FROZEN_DATE)
+    # This means you should be able to get 100/10 = 10 pages of results before
+    # you exceed the max_unique_records.
+    ingester.batch_limit = 10
+    ingester.max_unique_records = 100
+    ingester.process_large_batch = process_large_batch
+
+    response_json = _get_resource_json("flickr_example_pretty.json")
+    response_json["photos"]["page"] = page_number
+    response_json["photos"]["total"] = 200  # More than max unique records
+
+    actual_response = ingester.get_batch_data(response_json)
+    assert actual_response == expected_response
+
+    assert len(ingester.large_batches) == expected_large_batches
+
+
 def test_get_record_data():
     image_data = _get_resource_json("image_data_complete_example.json")
     actual_data = flickr.get_record_data(image_data)
@@ -342,3 +349,40 @@ def test_get_record_data_with_sub_provider():
         "category": None,
     }
     assert actual_data == expected_data
+
+
+def test_ingest_records():
+    # Test a 'normal' run where no large batches are detected during ingestion.
+    with (
+        mock.patch(
+            "providers.provider_api_scripts.time_delineated_provider_data_ingester.TimeDelineatedProviderDataIngester.ingest_records"
+        ),
+        mock.patch(
+            "providers.provider_api_scripts.time_delineated_provider_data_ingester.TimeDelineatedProviderDataIngester.ingest_records_for_timestamp_pair"
+        ) as ingest_for_pair_mock,
+    ):
+        flickr.ingest_records()
+        # No additional calls to ingest_records_for_timestamp_pair were made
+        assert not ingest_for_pair_mock.called
+
+
+def test_ingest_records_when_large_batches_detected():
+    ingester = FlickrDataIngester(date=FROZEN_DATE)
+    with (
+        mock.patch(
+            "providers.provider_api_scripts.time_delineated_provider_data_ingester.TimeDelineatedProviderDataIngester.ingest_records"
+        ),
+        mock.patch(
+            "providers.provider_api_scripts.time_delineated_provider_data_ingester.TimeDelineatedProviderDataIngester.ingest_records_for_timestamp_pair"
+        ) as ingest_for_pair_mock,
+    ):
+        # Set large_batches to include one timestamp pair
+        mock_start = datetime(2020, 1, 1, 1, 0, 0)
+        mock_end = datetime(2020, 1, 1, 2, 0, 0)
+        ingester.large_batches = [
+            (mock_start, mock_end),
+        ]
+
+        ingester.ingest_records()
+        # An additional call made to ingest_records_for_timestamp_pair for each license type
+        assert ingest_for_pair_mock.call_count == 8
diff --git a/tests/dags/providers/provider_api_scripts/test_time_delineated_provider_data_ingester.py b/tests/dags/providers/provider_api_scripts/test_time_delineated_provider_data_ingester.py
index a71ee49ae..b2de058fb 100644
--- a/tests/dags/providers/provider_api_scripts/test_time_delineated_provider_data_ingester.py
+++ b/tests/dags/providers/provider_api_scripts/test_time_delineated_provider_data_ingester.py
@@ -3,7 +3,6 @@
 from unittest.mock import MagicMock, call, patch
 
 import pytest
-from airflow.exceptions import AirflowException
 
 from tests.dags.providers.provider_api_scripts.resources.provider_data_ingester.mock_provider_data_ingester import (
     MAX_RECORDS,
@@ -205,7 +204,7 @@ def test_ts_pairs_and_kwargs_are_available_in_get_next_query_params():
 
 
 def test_ingest_records_raises_error_if_the_total_count_has_been_exceeded():
-    # Test that `ingest_records` raises an AirflowException if the external
+    # Test that `ingest_records` raises an Exception if the external
     # API continues returning data in excess of the stated `resultCount`
     # (https://github.com/WordPress/openverse-catalog/pull/934)
     with (
@@ -238,7 +237,7 @@ def test_ingest_records_raises_error_if_the_total_count_has_been_exceeded():
     )
     # Assert that attempting to ingest records raises an exception when
     # `should_raise_error` is enabled
-    with (pytest.raises(AirflowException, match=expected_error_string)):
+    with (pytest.raises(Exception, match=expected_error_string)):
         ingester.ingest_records()
 
     # get_mock should have been called 4 times, twice for each batch (once in `get_batch`
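
For reviewers, a minimal standalone sketch of the fallback strategy this diff describes in FlickrDataIngester.ingest_records: ingest in small time slices first, defer any slice whose record count exceeds the unique-record cap, then retry each deferred slice once per license, accepting a cap if a single license is still too large. The names below (count_records, ingest_interval, LICENSES, MAX_UNIQUE_RECORDS, ingest_day) are hypothetical stand-ins for illustration only, not the provider classes or Flickr API calls touched by this patch.

# Illustrative sketch only -- not the DAG code in this diff. count_records and
# ingest_interval are assumed helpers standing in for the provider's
# record-count query and paged ingestion request.
from __future__ import annotations

from datetime import datetime, timedelta

LICENSES = ["by", "by-nc", "by-nd", "by-sa", "by-nc-nd", "by-nc-sa", "pdm", "cc0"]
MAX_UNIQUE_RECORDS = 4_000  # Flickr stops returning unique results past this point


def count_records(start: datetime, end: datetime, license: str | None = None) -> int:
    """Stand-in for a cheap 'how many results would this query return?' call."""
    raise NotImplementedError


def ingest_interval(start: datetime, end: datetime, license: str | None = None) -> None:
    """Stand-in for paging through and storing every record in one query window."""
    raise NotImplementedError


def ingest_day(day_start: datetime) -> None:
    large_batches: list[tuple[datetime, datetime]] = []
    step = timedelta(minutes=5)

    # First pass: small time slices across the day. Slices that would exceed the
    # unique-record cap are deferred rather than ingested.
    start = day_start
    while start < day_start + timedelta(days=1):
        end = start + step
        if count_records(start, end) > MAX_UNIQUE_RECORDS:
            large_batches.append((start, end))
        else:
            ingest_interval(start, end)
        start = end

    # Second pass: retry each deferred slice one license at a time. If a single
    # license still exceeds the cap, the real ingester simply stops paging once
    # max_unique_records have been fetched and accepts that the rest is unreachable.
    for start, end in large_batches:
        for license in LICENSES:
            ingest_interval(start, end, license=license)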