diff --git a/src/crawlee/_types.py b/src/crawlee/_types.py
index 008a7fcf6a..da11adae5d 100644
--- a/src/crawlee/_types.py
+++ b/src/crawlee/_types.py
@@ -15,7 +15,7 @@
     import re
     from collections.abc import Callable, Coroutine, Sequence

-    from typing_extensions import NotRequired, Required, Unpack
+    from typing_extensions import NotRequired, Required, Self, Unpack

     from crawlee import Glob, Request
     from crawlee._request import RequestOptions
@@ -643,6 +643,25 @@ def __hash__(self) -> int:
         """Return hash of the context. Each context is considered unique."""
         return id(self)

+    def create_modified_copy(
+        self,
+        push_data: PushDataFunction | None = None,
+        add_requests: AddRequestsFunction | None = None,
+        get_key_value_store: GetKeyValueStoreFromRequestHandlerFunction | None = None,
+    ) -> Self:
+        """Create a modified copy of the crawling context with the specified changes."""
+        original_fields = {field.name: getattr(self, field.name) for field in dataclasses.fields(self)}
+        modified_fields = {
+            key: value
+            for key, value in {
+                'push_data': push_data,
+                'add_requests': add_requests,
+                'get_key_value_store': get_key_value_store,
+            }.items()
+            if value
+        }
+        return self.__class__(**{**original_fields, **modified_fields})
+

 class GetDataKwargs(TypedDict):
     """Keyword arguments for dataset's `get_data` method."""
diff --git a/src/crawlee/crawlers/_basic/_basic_crawler.py b/src/crawlee/crawlers/_basic/_basic_crawler.py
index 74d2aaff13..67363b6aa8 100644
--- a/src/crawlee/crawlers/_basic/_basic_crawler.py
+++ b/src/crawlee/crawlers/_basic/_basic_crawler.py
@@ -2,6 +2,7 @@
 from __future__ import annotations

 import asyncio
+import functools
 import logging
 import signal
 import sys
@@ -14,7 +15,7 @@
 from datetime import timedelta
 from functools import partial
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Generic, Literal, cast
+from typing import TYPE_CHECKING, Any, Generic, Literal, ParamSpec, cast
 from urllib.parse import ParseResult, urlparse
 from weakref import WeakKeyDictionary

@@ -96,6 +97,9 @@
 TCrawlingContext = TypeVar('TCrawlingContext', bound=BasicCrawlingContext, default=BasicCrawlingContext)
 TStatisticsState = TypeVar('TStatisticsState', bound=StatisticsState, default=StatisticsState)
 TRequestIterator = TypeVar('TRequestIterator', str, Request)
+TParams = ParamSpec('TParams')
+T = TypeVar('T')
+
 ErrorHandler = Callable[[TCrawlingContext, Exception], Awaitable[Request | None]]
 FailedRequestHandler = Callable[[TCrawlingContext, Exception], Awaitable[None]]
 SkippedRequestCallback = Callable[[str, SkippedReason], Awaitable[None]]
@@ -520,6 +524,24 @@ def stop(self, reason: str = 'Stop was called externally.') -> None:
         self._logger.info(f'Crawler.stop() was called with following reason: {reason}.')
         self._unexpected_stop = True

+    def _wrap_handler_with_error_context(
+        self, handler: Callable[[TCrawlingContext | BasicCrawlingContext, Exception], Awaitable[T]]
+    ) -> Callable[[TCrawlingContext | BasicCrawlingContext, Exception], Awaitable[T]]:
+        """Decorate error handlers to make their context helpers usable."""
+
+        @functools.wraps(handler)
+        async def wrapped_handler(context: TCrawlingContext | BasicCrawlingContext, exception: Exception) -> T:
+            # The original context helpers are backed by `RequestHandlerRunResult`, which is not committed because
+            # the request failed. The modified copy provides context helpers with direct access to the storages.
+            error_context = context.create_modified_copy(
+                push_data=self._push_data,
+                get_key_value_store=self.get_key_value_store,
+                add_requests=functools.partial(self._add_requests, context),
+            )
+            return await handler(error_context, exception)
+
+        return wrapped_handler
+
     def _stop_if_max_requests_count_exceeded(self) -> None:
         """Call `stop` when the maximum number of requests to crawl has been reached."""
         if self._max_requests_per_crawl is None:
@@ -618,7 +640,7 @@ def error_handler(

         The error handler is invoked after a request handler error occurs and before a retry attempt.
         """
-        self._error_handler = handler
+        self._error_handler = self._wrap_handler_with_error_context(handler)
         return handler

     def failed_request_handler(
@@ -628,7 +650,7 @@ def failed_request_handler(

         The failed request handler is invoked when a request has failed all retry attempts.
         """
-        self._failed_request_handler = handler
+        self._failed_request_handler = self._wrap_handler_with_error_context(handler)
         return handler

     def on_skipped_request(self, callback: SkippedRequestCallback) -> SkippedRequestCallback:
@@ -1256,52 +1278,46 @@ def _convert_url_to_request_iterator(self, urls: Sequence[str | Request], base_u
         else:
             yield Request.from_url(url)

-    async def _commit_request_handler_result(self, context: BasicCrawlingContext) -> None:
-        """Commit request handler result for the input `context`. Result is taken from `_context_result_map`."""
-        result = self._context_result_map[context]
-
-        base_request_manager = await self.get_request_manager()
-
-        origin = context.request.loaded_url or context.request.url
-
-        for add_requests_call in result.add_requests_calls:
-            rq_id = add_requests_call.get('rq_id')
-            rq_name = add_requests_call.get('rq_name')
-            rq_alias = add_requests_call.get('rq_alias')
-            specified_params = sum(1 for param in [rq_id, rq_name, rq_alias] if param is not None)
-            if specified_params > 1:
-                raise ValueError('You can only provide one of `rq_id`, `rq_name` or `rq_alias` arguments.')
-            if rq_id or rq_name or rq_alias:
-                request_manager: RequestManager | RequestQueue = await RequestQueue.open(
-                    id=rq_id,
-                    name=rq_name,
-                    alias=rq_alias,
-                    storage_client=self._service_locator.get_storage_client(),
-                    configuration=self._service_locator.get_configuration(),
-                )
-            else:
-                request_manager = base_request_manager
-
-            requests = list[Request]()
-
-            base_url = url if (url := add_requests_call.get('base_url')) else origin
-
-            requests_iterator = self._convert_url_to_request_iterator(add_requests_call['requests'], base_url)
+    async def _add_requests(
+        self,
+        context: BasicCrawlingContext,
+        requests: Sequence[str | Request],
+        rq_id: str | None = None,
+        rq_name: str | None = None,
+        rq_alias: str | None = None,
+        **kwargs: Unpack[EnqueueLinksKwargs],
+    ) -> None:
+        """Add requests to a request queue, resolving the base URL and crawl depth from the crawling context."""
+        if rq_id or rq_name or rq_alias:
+            request_manager: RequestManager = await RequestQueue.open(
+                id=rq_id,
+                name=rq_name,
+                alias=rq_alias,
+                storage_client=self._service_locator.get_storage_client(),
+                configuration=self._service_locator.get_configuration(),
+            )
+        else:
+            request_manager = await self.get_request_manager()

-            enqueue_links_kwargs: EnqueueLinksKwargs = {k: v for k, v in add_requests_call.items() if k != 'requests'}  # type: ignore[assignment]
+        context_aware_requests = list[Request]()
+        base_url = kwargs.get('base_url') or context.request.loaded_url or context.request.url
+        requests_iterator = self._convert_url_to_request_iterator(requests, base_url)
+        filter_requests_iterator = self._enqueue_links_filter_iterator(requests_iterator, context.request.url, **kwargs)
+        for dst_request in filter_requests_iterator:
+            # Update the crawl depth of the request.
+            dst_request.crawl_depth = context.request.crawl_depth + 1

-            filter_requests_iterator = self._enqueue_links_filter_iterator(
-                requests_iterator, context.request.url, **enqueue_links_kwargs
-            )
+            if self._max_crawl_depth is None or dst_request.crawl_depth <= self._max_crawl_depth:
+                context_aware_requests.append(dst_request)

-            for dst_request in filter_requests_iterator:
-                # Update the crawl depth of the request.
-                dst_request.crawl_depth = context.request.crawl_depth + 1
+        return await request_manager.add_requests(context_aware_requests)

-            if self._max_crawl_depth is None or dst_request.crawl_depth <= self._max_crawl_depth:
-                requests.append(dst_request)
+    async def _commit_request_handler_result(self, context: BasicCrawlingContext) -> None:
+        """Commit request handler result for the input `context`. Result is taken from `_context_result_map`."""
+        result = self._context_result_map[context]

-            await request_manager.add_requests(requests)
+        for add_requests_call in result.add_requests_calls:
+            await self._add_requests(context, **add_requests_call)

         for push_data_call in result.push_data_calls:
             await self._push_data(**push_data_call)
diff --git a/tests/unit/crawlers/_basic/test_basic_crawler.py b/tests/unit/crawlers/_basic/test_basic_crawler.py
index c7dad2725c..be22d1f951 100644
--- a/tests/unit/crawlers/_basic/test_basic_crawler.py
+++ b/tests/unit/crawlers/_basic/test_basic_crawler.py
@@ -284,6 +284,46 @@ async def failed_request_handler(context: BasicCrawlingContext, error: Exception
     assert isinstance(calls[0][1], RuntimeError)


+@pytest.mark.parametrize('handler', ['failed_request_handler', 'error_handler'])
+async def test_handlers_use_context_helpers(tmp_path: Path, handler: str) -> None:
+    """Test that context helpers used in `failed_request_handler` and in `error_handler` have an effect."""
+    # Prepare crawler
+    storage_client = FileSystemStorageClient()
+    crawler = BasicCrawler(
+        max_request_retries=1, storage_client=storage_client, configuration=Configuration(storage_dir=str(tmp_path))
+    )
+    # Test data
+    rq_alias = 'other'
+    test_data = {'some': 'data'}
+    test_key = 'key'
+    test_value = 'value'
+    test_request = Request.from_url('https://d.placeholder.com')
+
+    # Request handler with injected error
+    @crawler.router.default_handler
+    async def request_handler(context: BasicCrawlingContext) -> None:
+        raise RuntimeError('Arbitrary crash for testing purposes')
+
+    # Apply one of the handlers
+    @getattr(crawler, handler)  # type:ignore[misc]  # Untyped decorator is ok to make the test concise
+    async def handler_implementation(context: BasicCrawlingContext, error: Exception) -> None:
+        await context.push_data(test_data)
+        await context.add_requests(requests=[test_request], rq_alias=rq_alias)
+        kvs = await context.get_key_value_store()
+        await kvs.set_value(test_key, test_value)
+
+    await crawler.run(['https://b.placeholder.com'])
+
+    # Verify that the context helpers used in the handlers had an effect on the storages
+    dataset = await Dataset.open(storage_client=storage_client)
+    kvs = await KeyValueStore.open(storage_client=storage_client)
+    rq = await RequestQueue.open(alias=rq_alias, storage_client=storage_client)
+
+    assert test_value == await kvs.get_value(test_key)
+    assert [test_data] == (await dataset.get_data()).items
+    assert test_request == await rq.fetch_next_request()
+
+
 async def test_handles_error_in_failed_request_handler() -> None:
     crawler = BasicCrawler(max_request_retries=3)
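
For reference, a minimal usage sketch (outside the diff) of the behavior this change enables: context helpers called from an `error_handler` or `failed_request_handler` now write directly to the storages, since the failed request's `RequestHandlerRunResult` is never committed. It mirrors the new unit test above; the URLs, keys, and data values are illustrative only.

```python
import asyncio

from crawlee import Request
from crawlee.crawlers import BasicCrawler, BasicCrawlingContext


async def main() -> None:
    crawler = BasicCrawler(max_request_retries=1)

    @crawler.router.default_handler
    async def request_handler(context: BasicCrawlingContext) -> None:
        # Simulate a handler crash so the failed request handler runs.
        raise RuntimeError('Simulated handler failure')

    @crawler.failed_request_handler
    async def failed_handler(context: BasicCrawlingContext, error: Exception) -> None:
        # These helpers now hit the storages directly, even though the request
        # itself failed and its handler result was never committed.
        await context.push_data({'url': context.request.url, 'error': str(error)})
        await context.add_requests([Request.from_url('https://example.com/retry-later')])
        kvs = await context.get_key_value_store()
        await kvs.set_value('last-error', str(error))

    await crawler.run(['https://example.com'])


if __name__ == '__main__':
    asyncio.run(main())
```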