Skip to content

Commit

Permalink
Add support for per-request expiration, and related refactoring:
Browse files Browse the repository at this point in the history
* Move all expiration logic into a separate module, and add more tests
* If `expire_after` is numeric, expect it in seconds instead of hours (for consistency with Cache-Control)
* Add some more unit test coverage for CacheBackend
  • Loading branch information
JWCook committed Apr 8, 2021
1 parent 7e7edbd commit 0b64b5d
Show file tree
Hide file tree
Showing 16 changed files with 316 additions and 155 deletions.
3 changes: 2 additions & 1 deletion CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ progress on major and minor releases.
Here is a brief overview of the main classes and modules:
* `session.CacheMixin`, `session.CachedSession`: A mixin and wrapper class, respectively, for `aiohttp.ClientSession`. There is little logic here except wrapping `ClientSession._request()` with caching behavior.
* `response.CachedResponse`: A wrapper class built from an `aiohttp.ClientResponse`, with additional cache-related info. This is what is serialized and persisted to the cache.
* `backends.base.CacheBackend`: Most of the caching logic lives here, including saving and retrieving responses, creating cache keys, expiration, etc. It contains two `BaseCache` objects for storing responses and redirects, respectively. By default this is just a non-persistent dict cache.
* `backends.base.CacheBackend`: Most of the caching logic lives here, including saving and retriving responses. It contains two `BaseCache` objects for storing responses and redirects, respectively.
* `cache_keys` and `expiration`: Utilities for creating cache keys and cache expiration, respectively
* `backends.base.BaseCache`: Base class for lower-level storage operations, overridden by individual backends.
* Other backend implementations in `backends.*`: A backend implementation subclasses `CacheBackend` (for higher-level operations), as well as `BaseCache` (for lower-level operations).
3 changes: 2 additions & 1 deletion HISTORY.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
[See all issues & PRs here](https://github.com/JWCook/aiohttp-client-cache/milestone/2?closed=1)

* Add async implementation of DynamoDb backend
* Add support for setting different expiration times based on URL patterns
* Add support for expiration for individual requests
* Add support for expiration based on URL patterns
* Add support for serializing/deserializing `ClientSession.links`
* Add case-insensitive response headers for compatibility with aiohttp.ClientResponse.headers
* Add optional integration with `itsdangerous` for safer serialization
Expand Down
109 changes: 40 additions & 69 deletions aiohttp_client_cache/backends/base.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,18 @@
import pickle
from abc import ABCMeta, abstractmethod
from collections import UserDict
from datetime import datetime, timedelta
from fnmatch import fnmatch as glob_match
from datetime import datetime
from logging import getLogger
from typing import AsyncIterable, Callable, Dict, Iterable, Optional, Union
from urllib.parse import urlsplit
from typing import AsyncIterable, Callable, Iterable, Optional, Union

from aiohttp import ClientResponse
from aiohttp.typedefs import StrOrURL

from aiohttp_client_cache.cache_keys import create_key
from aiohttp_client_cache.expiration import ExpirationPatterns, ExpirationTime, get_expiration
from aiohttp_client_cache.response import AnyResponse, CachedResponse

ResponseOrKey = Union[CachedResponse, bytes, str, None]
ExpirationPatterns = Dict[str, Optional[timedelta]]
ExpirationTime = Union[int, float, timedelta, None]
logger = getLogger(__name__)


Expand All @@ -32,8 +29,8 @@ class CacheBackend:
Args:
cache_name: Cache prefix or namespace, depending on backend (see notes below)
expire_after: Expiration time, in hours, after which a cache entry will expire;
set to ``None`` to never expire
expire_after: Expiration time after which a cache entry will expire; may be a numeric value
in seconds, a :py:class:`.timedelta`, or ``-1`` to never expire
urls_expire_after: Expiration times to apply for different URL patterns (see notes below)
allowed_codes: Only cache responses with these status codes
allowed_methods: Only cache requests with these HTTP methods
Expand Down Expand Up @@ -75,16 +72,16 @@ class CacheBackend:
Example::
urls_expire_after = {
'*.site_1.com': 24,
'site_2.com/resource_1': 24 * 2,
'site_2.com/resource_2': 24 * 7,
'site_2.com/static': None,
'*.site_1.com': timedelta(days=1),
'site_2.com/resource_1': timedelta(hours=12),
'site_2.com/resource_2': 60,
'site_2.com/static': -1,
}
Notes:
* ``urls_expire_after`` should be a dict in the format ``{'pattern': expiration_time}``
* ``expiration_time`` may be either a number (in hours) or a ``timedelta``
* ``expiration_time`` may be either a number (in seconds) or a ``timedelta``
(same as ``expire_after``)
* Patterns will match request **base URLs**, so the pattern ``site.com/base`` is equivalent to
``https://site.com/base/**``
Expand All @@ -96,8 +93,8 @@ class CacheBackend:
def __init__(
self,
cache_name: str = 'aiohttp-cache',
expire_after: ExpirationTime = None,
urls_expire_after: Dict[str, ExpirationTime] = None,
expire_after: ExpirationTime = -1,
urls_expire_after: ExpirationPatterns = None,
allowed_codes: tuple = (200,),
allowed_methods: tuple = ('GET', 'HEAD'),
include_headers: bool = False,
Expand All @@ -108,10 +105,8 @@ def __init__(
serializer=None,
):
self.name = cache_name
self.expire_after = _convert_timedelta(expire_after)
self.urls_expire_after: ExpirationPatterns = {
_format_pattern(k): _convert_timedelta(v) for k, v in (urls_expire_after or {}).items()
}
self.expire_after = expire_after
self.urls_expire_after = urls_expire_after
self.allowed_codes = allowed_codes
self.allowed_methods = allowed_methods
self.filter_fn = filter_fn
Expand All @@ -138,25 +133,6 @@ def is_cacheable(self, response: Union[AnyResponse, None]) -> bool:
logger.debug(f'Pre-cache checks for response from {response.url}: {cache_criteria}') # type: ignore
return all(cache_criteria.values())

def get_expiration_date(self, response: ClientResponse) -> Optional[datetime]:
"""Get the absolute expiration time for a response, applying URL patterns if available"""
try:
expire_after = self._get_expiration_for_url(response)
except Exception:
expire_after = self.expire_after
return None if expire_after is None else datetime.utcnow() + expire_after

def _get_expiration_for_url(self, response: ClientResponse) -> Optional[timedelta]:
"""Get the relative expiration time matching the specified URL, if any. If there is no
match, raise a ``ValueError`` to differentiate beween this case and a matching pattern with
``expire_after=None``
"""
for pattern, expire_after in self.urls_expire_after.items():
if glob_match(_base_url(response.url), pattern):
logger.debug(f'URL {response.url} matched pattern "{pattern}": {expire_after}')
return expire_after
raise ValueError('No matching URL pattern')

async def get_response(self, key: str) -> Optional[CachedResponse]:
"""Retrieve response and timestamp for `key` if it's stored in cache,
otherwise returns ``None```
Expand Down Expand Up @@ -189,25 +165,37 @@ async def _get_redirect_response(self, key: str) -> Optional[CachedResponse]:
redirect_key = await self.redirects.read(key)
return await self.responses.read(redirect_key) if redirect_key else None # type: ignore

async def save_response(self, key: str, response: ClientResponse):
async def save_response(
self, key: str, response: ClientResponse, expire_after: ExpirationTime = None
):
"""Save response to cache
Args:
key: Key for this response
response: Response to save
expire_after: Expiration time to set only for this request; overrides
``CachedSession.expire_after`, and accepts all the same values.
"""
if not self.is_cacheable(response):
return
logger.info(f'Saving response for key: {key}')

expires = self.get_expiration_date(response)
cached_response = await CachedResponse.from_client_response(response, expires)
expire_after = self._get_expiration(response, expire_after)
cached_response = await CachedResponse.from_client_response(response, expire_after)
await self.responses.write(key, cached_response)

# Alias any redirect requests to the same cache key
for r in response.history:
await self.redirects.write(self.create_key(r.method, r.url), key)

def _get_expiration(
self, response: ClientResponse, request_expire_after: ExpirationTime = None
) -> Optional[datetime]:
"""Get the appropriate expiration for the given response"""
return get_expiration(
response, request_expire_after, self.expire_after, self.urls_expire_after
)

async def clear(self):
"""Clear cache"""
logger.info('Clearing cache')
Expand All @@ -228,23 +216,16 @@ async def delete_history(response):
await delete_history(await self.responses.pop(key))
await delete_history(await self.responses.pop(redirect_key))

async def delete_url(self, url: StrOrURL):
"""Delete cached response associated with `url`, along with its history (if applicable).
Works only for GET requests.
"""
await self.delete(self.create_key('GET', url))

async def delete_expired_responses(self):
"""Deletes entries from cache with creation time older than ``expire_after``.
**Note:** Also deletes any cache items that are filtered out according to ``filter_fn()``
and filter parameters (``allowable_*``)
"""Deletes all expired responses from the cache.
Also deletes any cache items that are filtered out according to ``filter_fn()``.
"""
logger.info(f'Deleting all responses more than {self.expire_after} hours old')
logger.info('Deleting all expired responses')
keys_to_delete = set()

async for key in self.responses.keys():
response = await self.get_response(key)
if response and response.is_expired:
response = await self.responses.read(key)
if response and response.is_expired or not self.filter_fn(response):
keys_to_delete.add(key)

logger.info(f'Deleting {len(keys_to_delete)} expired cache entries')
Expand All @@ -260,6 +241,12 @@ def create_key(self, *args, **kwargs):
**kwargs,
)

async def delete_url(self, url: StrOrURL):
"""Delete cached response associated with `url`, along with its history (if applicable).
Works only for GET requests.
"""
await self.delete(self.create_key('GET', url))

async def has_url(self, url: StrOrURL) -> bool:
"""Returns `True` if cache has `url`, `False` otherwise. Works only for GET request urls"""
key = self.create_key('GET', url)
Expand Down Expand Up @@ -399,19 +386,3 @@ async def values(self) -> AsyncIterable[ResponseOrKey]: # type: ignore

async def write(self, key: str, item: ResponseOrKey):
self.data[key] = item


def _base_url(url: StrOrURL) -> str:
url = str(url)
return url.replace(urlsplit(url).scheme + '://', '')


def _convert_timedelta(expire_after: ExpirationTime = None) -> Optional[timedelta]:
if expire_after is not None and not isinstance(expire_after, timedelta):
expire_after = timedelta(hours=expire_after)
return expire_after


def _format_pattern(pattern: str) -> str:
"""Add recursive wildcard to a glob pattern, to ensure it matches base URLs"""
return pattern.rstrip('*') + '**'
2 changes: 1 addition & 1 deletion aiohttp_client_cache/backends/mongo.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def __init__(
):
super().__init__(cache_name=cache_name, **kwargs)
self.responses = MongoDBPickleCache(cache_name, 'responses', connection, **kwargs)
self.keys_map = MongoDBCache(cache_name, 'redirects', self.responses.connection, **kwargs)
self.redirects = MongoDBCache(cache_name, 'redirects', self.responses.connection, **kwargs)


class MongoDBCache(BaseCache):
Expand Down
4 changes: 3 additions & 1 deletion aiohttp_client_cache/cache_keys.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
"""Functions for creating keys used for cache requests"""
import hashlib
from collections.abc import Mapping
from typing import Dict, List
Expand Down Expand Up @@ -41,7 +42,8 @@ def create_key(


def filter_ignored_params(data, ignored_params):
if not isinstance(data, Mapping):
"""Remove any ignored params from an object, if it's dict-like"""
if not isinstance(data, Mapping) or not ignored_params:
return data
return {k: v for k, v in data.items() if k not in ignored_params}

Expand Down
79 changes: 79 additions & 0 deletions aiohttp_client_cache/expiration.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
"""Functions for determining cache expiration"""
from datetime import datetime, timedelta
from fnmatch import fnmatch
from logging import getLogger
from typing import Dict, Optional, Union

from aiohttp import ClientResponse
from aiohttp.typedefs import StrOrURL

ExpirationTime = Union[None, int, float, datetime, timedelta]
ExpirationPatterns = Dict[str, ExpirationTime]
logger = getLogger(__name__)


def get_expiration(
    response: ClientResponse,
    request_expire_after: ExpirationTime = None,
    session_expire_after: ExpirationTime = None,
    urls_expire_after: ExpirationPatterns = None,
) -> Optional[datetime]:
    """Get the appropriate expiration for the given response, in order of precedence:
    1. Per-request expiration
    2. Per-URL expiration
    3. Per-session expiration

    Args:
        response: Response the expiration applies to; its URL is matched against
            ``urls_expire_after``
        request_expire_after: Expiration specified for this request only
        session_expire_after: Session-wide default expiration
        urls_expire_after: Mapping of URL glob patterns to expiration values

    Returns:
        An absolute expiration :py:class:`.datetime` or ``None``
    """
    # Use explicit None checks rather than `or`-chaining, so that falsy but valid
    # values such as 0 or timedelta(0) ("expire immediately") at a higher-precedence
    # level are not silently skipped in favor of a lower-precedence setting
    expire_after = request_expire_after
    if expire_after is None:
        expire_after = get_expiration_for_url(response.url, urls_expire_after)
    if expire_after is None:
        expire_after = session_expire_after
    return get_expiration_datetime(expire_after)


def get_expiration_datetime(expire_after: ExpirationTime) -> Optional[datetime]:
    """Normalize any supported expiration value into an absolute datetime.

    ``None`` and ``-1`` both mean "never expire"; a :py:class:`.datetime` is passed
    through unchanged; a number (in seconds) or a :py:class:`.timedelta` is added
    to the current UTC time.
    """
    logger.debug(f'Determining expiration time based on: {expire_after}')
    # "Never expire" sentinel values
    if expire_after is None or expire_after == -1:
        return None
    # Already an absolute time; nothing to convert
    if isinstance(expire_after, datetime):
        return expire_after

    # Relative time: accept either a timedelta or a numeric value in seconds
    delta = (
        expire_after
        if isinstance(expire_after, timedelta)
        else timedelta(seconds=expire_after)
    )
    return datetime.utcnow() + delta


def get_expiration_for_url(
    url: StrOrURL, urls_expire_after: ExpirationPatterns = None
) -> ExpirationTime:
    """Return the expiration value for the first pattern matching ``url``, if any.

    Patterns are checked in insertion order; the first match wins. Returns ``None``
    when no pattern matches (or no patterns were given).
    """
    patterns = urls_expire_after if urls_expire_after else {}
    for pattern in patterns:
        expire_after = patterns[pattern]
        if url_match(url, pattern):
            logger.debug(f'URL {url} matched pattern "{pattern}": {expire_after}')
            return expire_after
    return None


def url_match(url: StrOrURL, pattern: str) -> bool:
    """Determine if a URL matches a pattern

    Args:
        url: URL to test. Its base URL (without protocol) will be used.
        pattern: Glob pattern to match against. A recursive wildcard will be added if
            not present

    Example:
        >>> url_match('https://httpbin.org/delay/1', 'httpbin.org/delay')
        True
        >>> url_match('https://httpbin.org/stream/1', 'httpbin.org/*/1')
        True
        >>> url_match('https://httpbin.org/stream/2', 'httpbin.org/*/1')
        False
    """
    # An empty/missing URL can never match
    if not url:
        return False
    # Compare base URLs only: strip the protocol from both sides
    base_url = str(url).split('://')[-1]
    base_pattern = pattern.split('://')[-1]
    # Normalize to exactly one trailing recursive wildcard so the pattern
    # also matches any sub-paths of the given base URL
    base_pattern = base_pattern.rstrip('*') + '**'
    return fnmatch(base_url, base_pattern)
3 changes: 3 additions & 0 deletions aiohttp_client_cache/response.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import json
from datetime import datetime
from http.cookies import SimpleCookie
from logging import getLogger
from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple, Union

import attr
Expand All @@ -27,6 +28,8 @@
LinkItems = List[Tuple[str, DictItems]]
LinkMultiDict = MultiDictProxy[MultiDictProxy[Union[str, URL]]]

logger = getLogger(__name__)


@attr.s(slots=True)
class CachedResponse:
Expand Down
7 changes: 5 additions & 2 deletions aiohttp_client_cache/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from aiohttp.typedefs import StrOrURL

from aiohttp_client_cache.backends import CacheBackend
from aiohttp_client_cache.expiration import ExpirationTime
from aiohttp_client_cache.forge_utils import extend_signature, forge
from aiohttp_client_cache.response import AnyResponse

Expand All @@ -27,7 +28,9 @@ def __init__(self, *, cache: CacheBackend = None, **kwargs):
self.cache = cache or CacheBackend()

@forge.copy(ClientSession._request)
async def _request(self, method: str, str_or_url: StrOrURL, **kwargs) -> AnyResponse:
async def _request(
self, method: str, str_or_url: StrOrURL, expire_after: ExpirationTime = None, **kwargs
) -> AnyResponse:
"""Wrapper around :py:meth:`.SessionClient._request` that adds caching"""
cache_key = self.cache.create_key(method, str_or_url, **kwargs)

Expand All @@ -39,7 +42,7 @@ async def _request(self, method: str, str_or_url: StrOrURL, **kwargs) -> AnyResp
logger.info(f'Cached response not found; making request to {str_or_url}')
new_response = await super()._request(method, str_or_url, **kwargs) # type: ignore
await new_response.read()
await self.cache.save_response(cache_key, new_response)
await self.cache.save_response(cache_key, new_response, expire_after=expire_after)
return new_response

@asynccontextmanager
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ directory = 'test-reports'
[tool.coverage.run]
branch = true
source = ['aiohttp_client_cache']
omit = ['aiohttp_client_cache/backends/__init__.py']

[tool.isort]
profile = "black"
Expand Down
Loading

0 comments on commit 0b64b5d

Please sign in to comment.