Skip to content
This repository has been archived by the owner on Aug 4, 2023. It is now read-only.

Respect ingestion limit in process_batch #818

Merged
merged 1 commit into from
Oct 25, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -351,7 +351,7 @@ def process_batch(self, media_batch):
Returns the total count of records ingested up to this point, for all
media types.
"""
record_count = 0
processed_count = 0

for data in media_batch:
record_data = self.get_record_data(data)
Expand All @@ -375,9 +375,13 @@ def process_batch(self, media_batch):
# Add the record to the correct store
store = self.media_stores[media_type]
store.add_item(**record)
record_count += 1
processed_count += 1

return record_count
if self.limit and (self.record_count + processed_count) >= self.limit:
logger.info("Ingestion limit has been reached. Halting processing.")
return processed_count

return processed_count

@abstractmethod
def get_media_type(self, record: dict) -> str:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,25 @@ def test_process_batch_handles_list_of_records():
assert image_store_mock.call_count == 1


def test_process_batch_halts_processing_after_reaching_ingestion_limit():
    """Check that `process_batch` stops once the ingestion limit is met.

    A fresh ingester is capped at a limit of one record, and its
    `get_record_data` is mocked to return `MOCK_RECORD_DATA_LIST`. The count
    returned by `process_batch` must then equal that limit.
    """
    limited_ingester = MockProviderDataIngester()
    limited_ingester.limit = 1

    # Stub out `add_item` on both stores, and the record-parsing step
    with (
        patch.object(audio_store, "add_item"),
        patch.object(image_store, "add_item"),
        patch.object(limited_ingester, "get_record_data") as record_data_stub,
    ):
        # Each item in the batch resolves to the mocked list of records
        record_data_stub.return_value = MOCK_RECORD_DATA_LIST
        processed = limited_ingester.process_batch(EXPECTED_BATCH_DATA)

    # Processing stopped as soon as the limit of 1 was reached
    assert processed == 1


def test_ingest_records():
with (
patch.object(ingester, "get_batch") as get_batch_mock,
Expand Down