
Commit

Merge dc949da into e13e0ed
JWCook committed Apr 6, 2021
2 parents e13e0ed + dc949da commit e379e57
Showing 7 changed files with 74 additions and 28 deletions.
59 changes: 50 additions & 9 deletions aiohttp_client_cache/backends/base.py
@@ -303,10 +303,49 @@ class BaseCache(metaclass=ABCMeta):
"""A wrapper for lower-level cache storage operations. This is separate from
:py:class:`.CacheBackend` to allow a single backend to contain multiple cache objects.
This is no longer using a dict-like interface due to lack of python syntax support for async
dict operations.
Args:
secret_key: Optional secret key used to sign cache items for added security
salt: Optional salt used to sign cache items
serializer: Custom serializer that provides ``loads`` and ``dumps`` methods
"""

    def __init__(
        self,
        secret_key: Union[Iterable, str, bytes] = None,
        salt: Union[str, bytes] = b'aiohttp-client-cache',
        serializer=None,
        **kwargs,
    ):
        super().__init__()
        self._serializer = serializer or self._get_serializer(secret_key, salt)

    def serialize(self, item: Union[CachedResponse, str] = None) -> Optional[bytes]:
        """Serialize a URL or response into bytes"""
        return self._serializer.dumps(item) if item else None

    def deserialize(self, item: Union[str, bytes] = None) -> Union[CachedResponse, str, None]:
        """Deserialize a cached URL or response"""
        return self._serializer.loads(bytes(item)) if item else None

    # TODO: Remove once all backends have been updated to use serialize/deserialize
    @staticmethod
    def unpickle(result):
        return pickle.loads(bytes(result)) if result else None

    @staticmethod
    def _get_serializer(secret_key, salt):
        """Get the appropriate serializer to use: ``itsdangerous`` if a secret key is
        specified, or plain ``pickle`` otherwise.

        Raises:
            :py:exc:`ImportError` if ``secret_key`` is specified but ``itsdangerous`` is not installed
        """
        if secret_key:
            from itsdangerous.serializer import Serializer

            return Serializer(secret_key, salt=salt, serializer=pickle)
        else:
            return pickle

    @abstractmethod
    async def contains(self, key: str) -> bool:
        """Check if a key is stored in the cache"""
@@ -339,10 +378,6 @@ async def values(self) -> Iterable[ResponseOrKey]:
    async def write(self, key: str, item: ResponseOrKey):
        """Write an item to the cache"""

    @staticmethod
    def unpickle(result):
        return pickle.loads(bytes(result)) if result else None

    async def pop(self, key: str, default=None) -> ResponseOrKey:
        """Delete an item from the cache, and return the deleted item"""
        try:
@@ -357,7 +392,10 @@ class DictCache(BaseCache, UserDict):
"""Simple in-memory storage that wraps a dict with the :py:class:`.BaseStorage` interface"""

    async def delete(self, key: str):
        del self.data[key]
        try:
            del self.data[key]
        except KeyError:
            pass

    async def clear(self):
        self.data.clear()
@@ -368,8 +406,11 @@ async def contains(self, key: str) -> bool:
    async def keys(self) -> Iterable[str]:  # type: ignore
        return self.data.keys()

    async def read(self, key: str) -> Union[CachedResponse, str]:
        return self.data[key]
    async def read(self, key: str) -> Union[CachedResponse, str, None]:
        try:
            return self.data[key]
        except KeyError:
            return None

    async def size(self) -> int:
        return len(self.data)
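For context, here is a minimal sketch of what the new serializer path does when a ``secret_key`` is passed (assuming ``itsdangerous`` is installed; the key and payload below are made up for illustration, and a ``CachedResponse`` would normally take the place of the string):

import pickle
from itsdangerous import BadSignature
from itsdangerous.serializer import Serializer

# Same construction as BaseCache._get_serializer() above
serializer = Serializer('my-secret-key', salt=b'aiohttp-client-cache', serializer=pickle)

signed = serializer.dumps('https://example.com')  # pickled and signed bytes
assert serializer.loads(signed) == 'https://example.com'

try:
    serializer.loads(signed + b'tampered')  # signature check fails
except BadSignature:
    print('Cache item has been tampered with')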
1 change: 1 addition & 0 deletions aiohttp_client_cache/backends/dynamodb.py
@@ -1,3 +1,4 @@
# TODO: Use BaseCache.serialize() and deserialize()
import pickle
from typing import Dict, Iterable

5 changes: 2 additions & 3 deletions aiohttp_client_cache/backends/gridfs.py
@@ -1,4 +1,3 @@
import pickle
from typing import Iterable

from gridfs import GridFS
@@ -62,7 +61,7 @@ async def read(self, key: str) -> ResponseOrKey:
        result = self.fs.find_one({'_id': key})
        if result is None:
            raise KeyError
        return self.unpickle(bytes(result.read()))
        return self.deserialize(result.read())

    async def size(self) -> int:
        return self.db['fs.files'].count()
@@ -73,4 +72,4 @@ async def values(self) -> Iterable[ResponseOrKey]:

    async def write(self, key: str, item: ResponseOrKey):
        await self.delete(key)
        self.fs.put(pickle.dumps(item, protocol=-1), **{'_id': key})
        self.fs.put(self.serialize(item), **{'_id': key})
10 changes: 6 additions & 4 deletions aiohttp_client_cache/backends/mongo.py
@@ -1,4 +1,3 @@
import pickle
from typing import Iterable

from motor.motor_asyncio import AsyncIOMotorClient
@@ -34,7 +33,10 @@ class MongoDBCache(BaseCache):
        connection: MongoDB connection instance to use instead of creating a new one
    """

    def __init__(self, db_name, collection_name: str, connection: AsyncIOMotorClient = None):
    def __init__(
        self, db_name, collection_name: str, connection: AsyncIOMotorClient = None, **kwargs
    ):
        super().__init__(**kwargs)
        self.connection = connection or AsyncIOMotorClient()
        self.db = self.connection[db_name]
        self.collection = self.db[collection_name]
@@ -75,7 +77,7 @@ class MongoDBPickleCache(MongoDBCache):
"""Same as :py:class:`MongoDBCache`, but pickles values before saving"""

    async def read(self, key):
        return self.unpickle(bytes(await super().read(key)))
        return self.deserialize(await super().read(key))

    async def write(self, key, item):
        await super().write(key, pickle.dumps(item, protocol=-1))
        await super().write(key, self.serialize(item))
13 changes: 9 additions & 4 deletions aiohttp_client_cache/backends/redis.py
@@ -1,4 +1,3 @@
import pickle
from typing import Iterable

from aioredis import Redis, create_redis_pool
@@ -40,6 +39,12 @@ def __init__(
        connection: Redis = None,
        **kwargs,
    ):
        # Pop off any BaseCache kwargs and use the rest as Redis connection kwargs
        super().__init__(
            secret_key=kwargs.pop('secret_key', None),
            salt=kwargs.pop('salt', None),
            serializer=kwargs.pop('serializer', None),
        )
        self.address = address
        self._connection = connection
        self.connection_kwargs = kwargs
@@ -72,20 +77,20 @@ async def keys(self) -> Iterable[str]:
    async def read(self, key: str) -> ResponseOrKey:
        connection = await self.get_connection()
        result = await connection.hget(self.hash_key, key)
        return self.unpickle(result)
        return self.deserialize(result)

    async def size(self) -> int:
        connection = await self.get_connection()
        return await connection.hlen(self.hash_key)

    async def values(self) -> Iterable[ResponseOrKey]:
        connection = await self.get_connection()
        return [self.unpickle(v) for v in await connection.hvals(self.hash_key)]
        return [self.deserialize(v) for v in await connection.hvals(self.hash_key)]

    async def write(self, key: str, item: ResponseOrKey):
        connection = await self.get_connection()
        await connection.hset(
            self.hash_key,
            key,
            pickle.dumps(item, protocol=-1),
            self.serialize(item),
        )
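To make the intent of the kwargs handling above explicit: the cache-signing options are consumed first, and whatever remains is stored as Redis connection options. A hypothetical standalone sketch of that split (the helper name and example kwargs are illustrative, not part of the library):

def split_cache_kwargs(**kwargs):
    """Separate BaseCache options from connection options (illustration only)."""
    cache_kwargs = {
        'secret_key': kwargs.pop('secret_key', None),
        'salt': kwargs.pop('salt', None),
        'serializer': kwargs.pop('serializer', None),
    }
    return cache_kwargs, kwargs

cache_kwargs, connection_kwargs = split_cache_kwargs(secret_key='hunter2', db=1)
# cache_kwargs      -> {'secret_key': 'hunter2', 'salt': None, 'serializer': None}
# connection_kwargs -> {'db': 1}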
12 changes: 5 additions & 7 deletions aiohttp_client_cache/backends/sqlite.py
@@ -1,5 +1,4 @@
import asyncio
import pickle
import sqlite3
from contextlib import asynccontextmanager
from os.path import expanduser, splitext
@@ -48,7 +47,8 @@ class SQLiteCache(BaseCache):
        table_name: table name
    """

    def __init__(self, filename: str, table_name: str):
    def __init__(self, filename: str, table_name: str, **kwargs):
        super().__init__(**kwargs)
        self.filename = filename
        self.table_name = table_name
        self._can_commit = True  # Transactions can be committed if this is set to `True`
@@ -156,14 +156,12 @@ class SQLitePickleCache(SQLiteCache):
""" Same as :py:class:`SqliteCache`, but pickles values before saving """

    async def read(self, key: str) -> ResponseOrKey:
        item = await super().read(key)
        return pickle.loads(bytes(item)) if item else None  # type: ignore
        return self.deserialize(await super().read(key))

    async def values(self) -> Iterable[ResponseOrKey]:
        async with self.get_connection() as db:
            cur = await db.execute(f'select value from `{self.table_name}`')
            return [self.unpickle(row[0]) for row in await cur.fetchall()]
            return [self.deserialize(row[0]) for row in await cur.fetchall()]

    async def write(self, key, item):
        binary_item = sqlite3.Binary(pickle.dumps(item, protocol=-1))
        await super().write(key, binary_item)
        await super().write(key, sqlite3.Binary(self.serialize(item)))
2 changes: 1 addition & 1 deletion setup.py
@@ -40,7 +40,7 @@
    packages=find_packages(),
    include_package_data=True,
    version=__version__,
    install_requires=['aiohttp', 'attrs', 'python-forge', 'url-normalize'],
    install_requires=['aiohttp', 'attrs', 'itsdangerous', 'python-forge', 'url-normalize'],
    extras_require=extras_require,
    zip_safe=False,
)
