diff --git a/src/crawlee/storages/_request_queue.py b/src/crawlee/storages/_request_queue.py index ee8ad37a0c..639ef49ce6 100644 --- a/src/crawlee/storages/_request_queue.py +++ b/src/crawlee/storages/_request_queue.py @@ -348,7 +348,12 @@ async def _process_batch( unprocessed_requests_unique_keys = {request.unique_key for request in response.unprocessed_requests} retry_batch = [request for request in batch if request.unique_key in unprocessed_requests_unique_keys] await asyncio.sleep((base_retry_wait * attempt).total_seconds()) - await self._process_batch(retry_batch, base_retry_wait=base_retry_wait, attempt=attempt + 1) + await self._process_batch( + retry_batch, + base_retry_wait=base_retry_wait, + attempt=attempt + 1, + forefront=forefront, + ) request_count = len(batch) - len(response.unprocessed_requests) diff --git a/tests/unit/storages/test_request_queue.py b/tests/unit/storages/test_request_queue.py index 86e4b028ff..09ce769d9e 100644 --- a/tests/unit/storages/test_request_queue.py +++ b/tests/unit/storages/test_request_queue.py @@ -9,11 +9,12 @@ from crawlee import Request, service_locator from crawlee.configuration import Configuration from crawlee.storage_clients import MemoryStorageClient, StorageClient +from crawlee.storage_clients.models import AddRequestsResponse, ProcessedRequest, UnprocessedRequest from crawlee.storages import RequestQueue from crawlee.storages._storage_instance_manager import StorageInstanceManager if TYPE_CHECKING: - from collections.abc import AsyncGenerator + from collections.abc import AsyncGenerator, Sequence from crawlee.storage_clients import StorageClient @@ -259,6 +260,56 @@ async def test_add_requests_with_forefront(rq: RequestQueue) -> None: assert next_request.url == 'https://example.com/priority' +@pytest.mark.parametrize('forefront', [True, False]) +async def test_add_requests_retry_preserves_forefront( + monkeypatch: pytest.MonkeyPatch, + *, + forefront: bool, +) -> None: + """Regression test: when 
`add_batch_of_requests` returns unprocessed requests, the retry must preserve the + original `forefront` value rather than silently falling back to the parameter default.""" + rq = await RequestQueue.open(storage_client=MemoryStorageClient()) + forefront_calls: list[bool] = [] + + async def patched_add_batch( + requests: Sequence[Request], + *, + forefront: bool = False, + ) -> AddRequestsResponse: + forefront_calls.append(forefront) + if len(forefront_calls) == 1: + return AddRequestsResponse( + processed_requests=[], + unprocessed_requests=[UnprocessedRequest(unique_key=r.unique_key, url=r.url) for r in requests], + ) + return AddRequestsResponse( + processed_requests=[ + ProcessedRequest( + unique_key=r.unique_key, + was_already_present=False, + was_already_handled=False, + ) + for r in requests + ], + unprocessed_requests=[], + ) + + monkeypatch.setattr(rq._client, 'add_batch_of_requests', patched_add_batch) + + try: + await rq.add_requests( + ['https://example.com/a', 'https://example.com/b'], + forefront=forefront, + wait_time_between_batches=timedelta(seconds=0), + ) + finally: + await rq.drop() + + assert forefront_calls == [forefront, forefront], ( + f'retry must propagate the original forefront={forefront} flag, got: {forefront_calls}' + ) + + async def test_add_requests_mixed_forefront(rq: RequestQueue) -> None: + """Test the ordering when adding requests with mixed forefront values.""" + # Add normal requests