Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor cache module #24

Merged
merged 1 commit into from
Apr 29, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion src/fastapi_redis_cache/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,11 @@
# flake8: noqa
from fastapi_redis_cache.cache import cache
from fastapi_redis_cache.cache import (
cache,
cache_one_day,
cache_one_hour,
cache_one_minute,
cache_one_month,
cache_one_week,
cache_one_year,
)
from fastapi_redis_cache.client import FastApiRedisCache
64 changes: 44 additions & 20 deletions src/fastapi_redis_cache/cache.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,31 @@
"""cache.py"""
import asyncio
from datetime import timedelta
from functools import wraps
from functools import partial, update_wrapper, wraps
from http import HTTPStatus
from typing import Union

from fastapi import Response

from fastapi_redis_cache.client import FastApiRedisCache
from fastapi_redis_cache.util import (
deserialize_json,
ONE_DAY_IN_SECONDS,
ONE_HOUR_IN_SECONDS,
ONE_MONTH_IN_SECONDS,
ONE_WEEK_IN_SECONDS,
ONE_YEAR_IN_SECONDS,
serialize_json,
)


def cache(*, expire_after_seconds: Union[int, timedelta] = None):
def cache(*, expire: Union[int, timedelta] = ONE_YEAR_IN_SECONDS):
"""Enable caching behavior for the decorated function.

If no arguments are provided, this marks the response data for the decorated
path function as "never expires". In this case, the `Expires` and
`Cache-Control: max-age` headers will be set to expire after one year.
Historically, this was the furthest time in the future that was allowed for
these fields. This is no longer the case, but it is still not advisable to use
values greater than one year.

Args:
expire_after_seconds (Union[int, timedelta], optional): The number of seconds
from now when the cached response should expire. Defaults to None.
expire (Union[int, timedelta], optional): The number of seconds
from now when the cached response should expire. Defaults to 31,536,000
seconds (i.e., the number of seconds in one year).
"""

def outer_wrapper(func):
Expand All @@ -33,32 +36,32 @@ async def inner_wrapper(*args, **kwargs):
func_kwargs = kwargs.copy()
request = func_kwargs.pop("request", None)
response = func_kwargs.pop("response", None)
create_response_directly = False
if not response:
create_response_directly = not response
if create_response_directly:
response = Response()
create_response_directly = True
redis_cache = FastApiRedisCache()

# if the redis client is not connected or request is not cacheable, no caching behavior is performed.
if redis_cache.not_connected or redis_cache.request_is_not_cacheable(request):
# if the redis client is not connected or request is not cacheable, no caching behavior is performed.
return await get_api_response_async(func, *args, **kwargs)
key = redis_cache.get_cache_key(func, *args, **kwargs)
ttl, in_cache = redis_cache.check_cache(key)
if in_cache:
if redis_cache.requested_resource_not_modified(request, in_cache):
response.status_code = int(HTTPStatus.NOT_MODIFIED)
return response
cached_data = redis_cache.deserialize_json(in_cache)
cached_data = deserialize_json(in_cache)
redis_cache.set_response_headers(response, cache_hit=True, response_data=cached_data, ttl=ttl)
if create_response_directly:
return Response(content=in_cache, media_type="application/json", headers=response.headers)
return cached_data
response_data = await get_api_response_async(func, *args, **kwargs)
redis_cache.add_to_cache(key, response_data, expire_after_seconds)
redis_cache.set_response_headers(response, cache_hit=False, response_data=response_data, ttl=ttl)
ttl = calculate_ttl(expire)
cached = redis_cache.add_to_cache(key, response_data, ttl)
if cached:
redis_cache.set_response_headers(response, cache_hit=False, response_data=response_data, ttl=ttl)
if create_response_directly:
return Response(
content=redis_cache.serialize_json(response_data),
content=serialize_json(response_data),
media_type="application/json",
headers=response.headers,
)
Expand All @@ -72,3 +75,24 @@ async def inner_wrapper(*args, **kwargs):
async def get_api_response_async(func, *args, **kwargs):
"""Helper function that allows decorator to work with both async and non-async functions."""
return await func(*args, **kwargs) if asyncio.iscoroutinefunction(func) else func(*args, **kwargs)


def calculate_ttl(expire: Union[int, timedelta]) -> int:
if isinstance(expire, timedelta):
expire = int(expire.total_seconds())
return min(expire, ONE_YEAR_IN_SECONDS)


cache_one_minute = partial(cache, expire=60)
cache_one_hour = partial(cache, expire=ONE_HOUR_IN_SECONDS)
cache_one_day = partial(cache, expire=ONE_DAY_IN_SECONDS)
cache_one_week = partial(cache, expire=ONE_WEEK_IN_SECONDS)
cache_one_month = partial(cache, expire=ONE_MONTH_IN_SECONDS)
cache_one_year = partial(cache, expire=ONE_YEAR_IN_SECONDS)

update_wrapper(cache_one_minute, cache)
update_wrapper(cache_one_hour, cache)
update_wrapper(cache_one_day, cache)
update_wrapper(cache_one_week, cache)
update_wrapper(cache_one_month, cache)
update_wrapper(cache_one_year, cache)
6 changes: 6 additions & 0 deletions src/fastapi_redis_cache/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,12 @@
DATETIME_AWARE = "%m/%d/%Y %I:%M:%S %p %z"
DATE_ONLY = "%m/%d/%Y"

ONE_HOUR_IN_SECONDS = 3600
ONE_DAY_IN_SECONDS = ONE_HOUR_IN_SECONDS * 24
ONE_WEEK_IN_SECONDS = ONE_DAY_IN_SECONDS * 7
ONE_MONTH_IN_SECONDS = ONE_DAY_IN_SECONDS * 30
ONE_YEAR_IN_SECONDS = ONE_DAY_IN_SECONDS * 365

SERIALIZE_OBJ_MAP = {
str(datetime): parser.parse,
str(date): parser.parse,
Expand Down