21 changes: 20 additions & 1 deletion src/crawlee/_types.py
@@ -15,7 +15,7 @@
import re
from collections.abc import Callable, Coroutine, Sequence

from typing_extensions import NotRequired, Required, Unpack
from typing_extensions import NotRequired, Required, Self, Unpack

from crawlee import Glob, Request
from crawlee._request import RequestOptions
@@ -643,6 +643,25 @@ def __hash__(self) -> int:
"""Return hash of the context. Each context is considered unique."""
return id(self)

def create_modified_copy(
self,
push_data: PushDataFunction | None = None,
add_requests: AddRequestsFunction | None = None,
get_key_value_store: GetKeyValueStoreFromRequestHandlerFunction | None = None,
) -> Self:
"""Create a modified copy of the crawling context with specified changes."""
original_fields = {field.name: getattr(self, field.name) for field in dataclasses.fields(self)}
modified_fields = {
key: value
for key, value in {
'push_data': push_data,
'add_requests': add_requests,
'get_key_value_store': get_key_value_store,
}.items()
if value
}
return self.__class__(**{**original_fields, **modified_fields})


class GetDataKwargs(TypedDict):
"""Keyword arguments for dataset's `get_data` method."""
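A minimal, self-contained sketch (not part of the diff) of the copy-and-override pattern the new `create_modified_copy` method above follows: every dataclass field of the original instance is copied, and only the helpers that were explicitly passed in replace their originals. `ExampleContext` and its fields are hypothetical stand-ins, not crawlee classes.

import dataclasses
from collections.abc import Callable
from dataclasses import dataclass


@dataclass(frozen=True)
class ExampleContext:
    """Hypothetical stand-in for a crawling context with a swappable helper."""

    url: str
    push_data: Callable[[dict], None]

    def create_modified_copy(self, push_data: Callable[[dict], None] | None = None) -> 'ExampleContext':
        # Copy every field of the original instance...
        original_fields = {field.name: getattr(self, field.name) for field in dataclasses.fields(self)}
        # ...then keep only the overrides that were actually provided.
        overrides = {key: value for key, value in {'push_data': push_data}.items() if value}
        return self.__class__(**{**original_fields, **overrides})


original = ExampleContext(url='https://example.com', push_data=lambda data: None)
modified = original.create_modified_copy(push_data=lambda data: print('stored:', data))
assert modified.url == original.url  # untouched fields carry over unchanged

Filtering on `if value` mirrors the diff: passing `None` keeps the original helper, so callers only name the pieces they want to swap.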
104 changes: 60 additions & 44 deletions src/crawlee/crawlers/_basic/_basic_crawler.py
@@ -2,6 +2,7 @@
from __future__ import annotations

import asyncio
import functools
import logging
import signal
import sys
@@ -14,7 +15,7 @@
from datetime import timedelta
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic, Literal, cast
from typing import TYPE_CHECKING, Any, Generic, Literal, ParamSpec, cast
from urllib.parse import ParseResult, urlparse
from weakref import WeakKeyDictionary

@@ -96,6 +97,9 @@
TCrawlingContext = TypeVar('TCrawlingContext', bound=BasicCrawlingContext, default=BasicCrawlingContext)
TStatisticsState = TypeVar('TStatisticsState', bound=StatisticsState, default=StatisticsState)
TRequestIterator = TypeVar('TRequestIterator', str, Request)
TParams = ParamSpec('TParams')
T = TypeVar('T')

ErrorHandler = Callable[[TCrawlingContext, Exception], Awaitable[Request | None]]
FailedRequestHandler = Callable[[TCrawlingContext, Exception], Awaitable[None]]
SkippedRequestCallback = Callable[[str, SkippedReason], Awaitable[None]]
@@ -520,6 +524,24 @@ def stop(self, reason: str = 'Stop was called externally.') -> None:
self._logger.info(f'Crawler.stop() was called with following reason: {reason}.')
self._unexpected_stop = True

def _wrap_handler_with_error_context(
self, handler: Callable[[TCrawlingContext | BasicCrawlingContext, Exception], Awaitable[T]]
) -> Callable[[TCrawlingContext | BasicCrawlingContext, Exception], Awaitable[T]]:
"""Decorate error handlers to make their context helpers usable."""

@functools.wraps(handler)
async def wrapped_handler(context: TCrawlingContext | BasicCrawlingContext, exception: Exception) -> T:
# The original context helpers come from `RequestHandlerRunResult` and will not be committed because the
# request failed. The modified context provides helpers with direct access to the storages.
error_context = context.create_modified_copy(
push_data=self._push_data,
get_key_value_store=self.get_key_value_store,
add_requests=functools.partial(self._add_requests, context),
)
return await handler(error_context, exception)

return wrapped_handler

def _stop_if_max_requests_count_exceeded(self) -> None:
"""Call `stop` when the maximum number of requests to crawl has been reached."""
if self._max_requests_per_crawl is None:
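The new `_wrap_handler_with_error_context` above is a plain decorator pattern: `functools.wraps` preserves the handler's metadata, and every invocation first swaps the context for one whose helpers talk to storage directly. Below is a reduced sketch of that same shape using plain dicts instead of crawling contexts; all names are illustrative, not crawlee APIs.

import asyncio
import functools
from collections.abc import Awaitable, Callable


def wrap_with_error_context(
    handler: Callable[[dict, Exception], Awaitable[None]],
    swap_context: Callable[[dict], dict],
) -> Callable[[dict, Exception], Awaitable[None]]:
    @functools.wraps(handler)  # keep the original handler's name and docstring
    async def wrapped(context: dict, exception: Exception) -> None:
        # Replace the context before delegating, just as the crawler swaps in
        # storage-backed helpers for its error handlers.
        return await handler(swap_context(context), exception)

    return wrapped


async def my_error_handler(context: dict, exception: Exception) -> None:
    print(context['helpers'], '->', exception)


wrapped = wrap_with_error_context(my_error_handler, lambda ctx: {**ctx, 'helpers': 'direct-storage'})
asyncio.run(wrapped({'helpers': 'buffered-result'}, RuntimeError('boom')))  # prints: direct-storage -> boom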
@@ -618,7 +640,7 @@ def error_handler(

The error handler is invoked after a request handler error occurs and before a retry attempt.
"""
self._error_handler = handler
self._error_handler = self._wrap_handler_with_error_context(handler)
return handler

def failed_request_handler(
@@ -628,7 +650,7 @@ def failed_request_handler(

The failed request handler is invoked when a request has failed all retry attempts.
"""
self._failed_request_handler = handler
self._failed_request_handler = self._wrap_handler_with_error_context(handler)
return handler

def on_skipped_request(self, callback: SkippedRequestCallback) -> SkippedRequestCallback:
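Both registration decorators keep their public contract: the crawler stores the wrapped variant internally but hands the caller back the original function. A hedged usage sketch follows; the import path and handler body are assumptions, not taken from the diff.

import asyncio

from crawlee.crawlers import BasicCrawler, BasicCrawlingContext


async def main() -> None:
    crawler = BasicCrawler(max_request_retries=1)

    @crawler.error_handler
    async def log_and_store(context: BasicCrawlingContext, error: Exception) -> None:
        # With the wrapping above, push_data writes straight to the dataset
        # instead of going through the discarded per-request result.
        await context.push_data({'url': context.request.url, 'error': str(error)})

    # The decorator returns the original coroutine function, not the wrapped one.
    assert log_and_store.__name__ == 'log_and_store'


asyncio.run(main())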
@@ -1256,52 +1278,46 @@ def _convert_url_to_request_iterator(self, urls: Sequence[str | Request], base_u
else:
yield Request.from_url(url)

    async def _commit_request_handler_result(self, context: BasicCrawlingContext) -> None:
        """Commit request handler result for the input `context`. Result is taken from `_context_result_map`."""
        result = self._context_result_map[context]

        base_request_manager = await self.get_request_manager()

        origin = context.request.loaded_url or context.request.url

        for add_requests_call in result.add_requests_calls:
            rq_id = add_requests_call.get('rq_id')
            rq_name = add_requests_call.get('rq_name')
            rq_alias = add_requests_call.get('rq_alias')
            specified_params = sum(1 for param in [rq_id, rq_name, rq_alias] if param is not None)
            if specified_params > 1:
                raise ValueError('You can only provide one of `rq_id`, `rq_name` or `rq_alias` arguments.')
            if rq_id or rq_name or rq_alias:
                request_manager: RequestManager | RequestQueue = await RequestQueue.open(
                    id=rq_id,
                    name=rq_name,
                    alias=rq_alias,
                    storage_client=self._service_locator.get_storage_client(),
                    configuration=self._service_locator.get_configuration(),
                )
            else:
                request_manager = base_request_manager

            requests = list[Request]()

            base_url = url if (url := add_requests_call.get('base_url')) else origin

            requests_iterator = self._convert_url_to_request_iterator(add_requests_call['requests'], base_url)

            enqueue_links_kwargs: EnqueueLinksKwargs = {k: v for k, v in add_requests_call.items() if k != 'requests'}  # type: ignore[assignment]

            filter_requests_iterator = self._enqueue_links_filter_iterator(
                requests_iterator, context.request.url, **enqueue_links_kwargs
            )

            for dst_request in filter_requests_iterator:
                # Update the crawl depth of the request.
                dst_request.crawl_depth = context.request.crawl_depth + 1

                if self._max_crawl_depth is None or dst_request.crawl_depth <= self._max_crawl_depth:
                    requests.append(dst_request)

            await request_manager.add_requests(requests)

    async def _add_requests(
        self,
        context: BasicCrawlingContext,
        requests: Sequence[str | Request],
        rq_id: str | None = None,
        rq_name: str | None = None,
        rq_alias: str | None = None,
        **kwargs: Unpack[EnqueueLinksKwargs],
    ) -> None:
        """Add requests method aware of the crawling context."""
        if rq_id or rq_name or rq_alias:
            request_manager: RequestManager = await RequestQueue.open(
                id=rq_id,
                name=rq_name,
                alias=rq_alias,
                storage_client=self._service_locator.get_storage_client(),
                configuration=self._service_locator.get_configuration(),
            )
        else:
            request_manager = await self.get_request_manager()

        context_aware_requests = list[Request]()
        base_url = kwargs.get('base_url') or context.request.loaded_url or context.request.url
        requests_iterator = self._convert_url_to_request_iterator(requests, base_url)
        filter_requests_iterator = self._enqueue_links_filter_iterator(requests_iterator, context.request.url, **kwargs)
        for dst_request in filter_requests_iterator:
            # Update the crawl depth of the request.
            dst_request.crawl_depth = context.request.crawl_depth + 1

            if self._max_crawl_depth is None or dst_request.crawl_depth <= self._max_crawl_depth:
                context_aware_requests.append(dst_request)

        return await request_manager.add_requests(context_aware_requests)

    async def _commit_request_handler_result(self, context: BasicCrawlingContext) -> None:
        """Commit request handler result for the input `context`. Result is taken from `_context_result_map`."""
        result = self._context_result_map[context]

        for add_requests_call in result.add_requests_calls:
            await self._add_requests(context, **add_requests_call)

        for push_data_call in result.push_data_calls:
            await self._push_data(**push_data_call)
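The refactor above reduces to one delegation pattern: recorded `add_requests` calls are kept as dicts and replayed by unpacking them into the new context-aware `_add_requests`, while error handlers receive the same method with the context pre-bound via `functools.partial`. A small standalone sketch of that pattern follows; all names below are illustrative, not crawlee APIs.

import asyncio
import functools
from typing import TypedDict


class AddRequestsCall(TypedDict, total=False):
    """Recorded keyword arguments of a single add_requests call."""

    requests: list[str]
    rq_alias: str
    base_url: str


async def add_requests(context: str, requests: list[str], rq_alias: str | None = None, base_url: str | None = None) -> None:
    print(f'{context}: queueing {requests} into {rq_alias or "default"} (base: {base_url})')


async def main() -> None:
    # Commit path: replay a recorded call by unpacking it into the context-aware method.
    recorded_call: AddRequestsCall = {'requests': ['https://example.com/a'], 'rq_alias': 'other'}
    await add_requests('request-A', **recorded_call)

    # Error-handler path: bind the context up front, mirroring
    # functools.partial(self._add_requests, context) in the wrapper.
    context_bound = functools.partial(add_requests, 'request-B')
    await context_bound(['https://example.com/b'], base_url='https://example.com')


asyncio.run(main())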
40 changes: 40 additions & 0 deletions tests/unit/crawlers/_basic/test_basic_crawler.py
@@ -284,6 +284,46 @@ async def failed_request_handler(context: BasicCrawlingContext, error: Exception
assert isinstance(calls[0][1], RuntimeError)


@pytest.mark.parametrize('handler', ['failed_request_handler', 'error_handler'])
async def test_handlers_use_context_helpers(tmp_path: Path, handler: str) -> None:
"""Test that context helpers used in `failed_request_handler` and in `error_handler` have effect."""
# Prepare crawler
storage_client = FileSystemStorageClient()
crawler = BasicCrawler(
max_request_retries=1, storage_client=storage_client, configuration=Configuration(storage_dir=str(tmp_path))
)
# Test data
rq_alias = 'other'
test_data = {'some': 'data'}
test_key = 'key'
test_value = 'value'
test_request = Request.from_url('https://d.placeholder.com')

# Request handler with injected error
@crawler.router.default_handler
async def request_handler(context: BasicCrawlingContext) -> None:
raise RuntimeError('Arbitrary crash for testing purposes')

# Apply one of the handlers
@getattr(crawler, handler) # type:ignore[misc] # Untyped decorator is ok to make the test concise
async def handler_implementation(context: BasicCrawlingContext, error: Exception) -> None:
await context.push_data(test_data)
await context.add_requests(requests=[test_request], rq_alias=rq_alias)
kvs = await context.get_key_value_store()
await kvs.set_value(test_key, test_value)

await crawler.run(['https://b.placeholder.com'])

# Verify that the context helpers used in the handlers had an effect on the underlying storages
dataset = await Dataset.open(storage_client=storage_client)
kvs = await KeyValueStore.open(storage_client=storage_client)
rq = await RequestQueue.open(alias=rq_alias, storage_client=storage_client)

assert test_value == await kvs.get_value(test_key)
assert [test_data] == (await dataset.get_data()).items
assert test_request == await rq.fetch_next_request()


async def test_handles_error_in_failed_request_handler() -> None:
crawler = BasicCrawler(max_request_retries=3)
