-
Notifications
You must be signed in to change notification settings - Fork 20
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
4 changed files
with
237 additions
and
108 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,149 +1,192 @@ | ||
# TODO: Use BaseCache.serialize() and deserialize() | ||
import pickle | ||
from typing import Dict, Iterable | ||
import logging | ||
from typing import AsyncIterable, Dict, Optional | ||
|
||
import boto3 | ||
from boto3.resources.base import ServiceResource | ||
import aioboto3 | ||
from aioboto3.session import ResourceCreatorContext | ||
from botocore.exceptions import ClientError | ||
|
||
from aiohttp_client_cache.backends import BaseCache, CacheBackend, ResponseOrKey | ||
from aiohttp_client_cache.forge_utils import extend_signature | ||
from aiohttp_client_cache.response import CachedResponse | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class DynamoDBBackend(CacheBackend):
    """DynamoDB cache backend.

    See :py:class:`.DynamoDbCache` for backend-specific options.
    See `DynamoDB Service Resource
    <https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/dynamodb.html#service-resource>`_
    for more usage details.
    See :py:class:`.CacheBackend` for args.
    """

    @extend_signature(CacheBackend.__init__)
    def __init__(
        self,
        cache_name: str = 'aiohttp-cache',
        key_attr_name: str = 'k',
        val_attr_name: str = 'v',
        create_if_not_exists: bool = False,
        context: Optional[ResourceCreatorContext] = None,
        **kwargs,
    ):
        super().__init__(cache_name=cache_name, **kwargs)
        # Share a single service-resource context between both caches so they
        # reuse one underlying aioboto3 connection
        if not context:
            context = aioboto3.resource("dynamodb")
        self.responses = DynamoDbCache(
            cache_name, 'resp', key_attr_name, val_attr_name, create_if_not_exists, context
        )
        self.redirects = DynamoDbCache(
            cache_name, 'redir', key_attr_name, val_attr_name, create_if_not_exists, context
        )

    async def get_response(self, key: str) -> Optional[CachedResponse]:
        """Fetch a cached response for ``key``, following one level of redirect.

        Returns ``None`` when nothing is cached, or when the cached entry is no
        longer cacheable (in which case the entry is deleted).
        """
        logger.debug(f'Attempting to get cached response for key: {key}')

        # Avoiding calling contains here
        response = await self.responses.read(key)
        if not response:
            redirect_key = await self.redirects.read(key)
            if redirect_key:
                response = await self.responses.read(redirect_key)

        if not response:
            logger.debug('No cached response found')
            return None

        # If the item is expired or filtered out, delete it from the cache
        if not self.is_cacheable(response):
            logger.info('Cached response expired; deleting')
            await self.delete(key)
            return None

        logger.info(f'Cached response found for key: {key}')
        return response
|
||
|
||
# TODO: Incomplete/untested | ||
# TODO: Fully async implementation. Current implementation with boto3 uses blocking operations. | ||
# Methods are currently defined as async only for compatibility with BaseCache API. | ||
class DynamoDbCache(BaseCache):
    """An async-compatible interface for caching objects in a DynamoDB key-store.

    The actual key name on the dynamodb server will be ``namespace:key``.
    In order to deal with how dynamodb stores data/keys, all values must be pickled.

    Args:
        table_name: Table name to use
        namespace: Prefix to be prepended to key in the DynamoDB document
        key_attr_name: The name of the field to use for keys in the DynamoDB document
        val_attr_name: The name of the field to use for values in the DynamoDB document
        create_if_not_exists: Whether or not to attempt to create the DynamoDB table
        context: An existing ResourceCreatorContext (see ``aioboto3.resource()``) to
            reuse instead of creating a new one
    """

    def __init__(
        self,
        table_name: str,
        namespace: str,
        key_attr_name: str,
        val_attr_name: str,
        create_if_not_exists: bool,
        context: ResourceCreatorContext,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.table_name = table_name
        self.namespace = namespace
        self.key_attr_name = key_attr_name
        self.val_attr_name = val_attr_name
        self.create_if_not_exists = create_if_not_exists
        self.context = context
        self._table = None  # Created lazily by get_table()

    async def get_table(self):
        """Return the table resource, creating the table on first use if requested."""
        if not self._table:
            # Re-use the service resource if it's already been created
            if self.context.cls:
                conn = self.context.cls
            else:
                # NOTE(review): the context is entered here but __aexit__ is never
                # called — confirm the caller owns cleanup of this resource
                conn = await self.context.__aenter__()

            self._table = await conn.Table(self.table_name)
            if self.create_if_not_exists:
                try:
                    await conn.create_table(
                        AttributeDefinitions=[
                            {
                                'AttributeName': self.key_attr_name,
                                'AttributeType': 'S',
                            },
                        ],
                        TableName=self.table_name,
                        KeySchema=[
                            {
                                'AttributeName': self.key_attr_name,
                                'KeyType': 'HASH',
                            },
                        ],
                        BillingMode="PAY_PER_REQUEST",
                    )
                    await self._table.wait_until_exists()
                except ClientError as e:
                    # Another process may have created the table first; only
                    # re-raise unexpected errors
                    if e.response["Error"]["Code"] != "ResourceInUseException":
                        raise

        return self._table

    def _doc(self, key) -> Dict:
        """Build the primary-key document for a namespaced key."""
        return {self.key_attr_name: f'{self.namespace}:{key}'}

    async def _scan(self) -> AsyncIterable[Dict]:
        """Yield every item in this cache's namespace, paging through scan results."""
        table = await self.get_table()
        paginator = table.meta.client.get_paginator('scan')
        pages = paginator.paginate(
            TableName=table.name,
            Select='ALL_ATTRIBUTES',
            FilterExpression=f'begins_with({self.key_attr_name}, :namespace)',
            ExpressionAttributeValues={':namespace': f'{self.namespace}:'},
        )
        async for page in pages:
            for item in page['Items']:
                yield item

    async def delete(self, key: str) -> None:
        table = await self.get_table()
        await table.delete_item(Key=self._doc(key))

    async def read(self, key: str) -> ResponseOrKey:
        """Fetch and deserialize a single value; returns None when absent."""
        table = await self.get_table()
        response = await table.get_item(
            Key=self._doc(key), ProjectionExpression=self.val_attr_name
        )
        item = response.get("Item")
        if not item:
            return None
        return self.deserialize(item[self.val_attr_name].value)

    async def write(self, key: str, item: ResponseOrKey) -> None:
        table = await self.get_table()
        doc = self._doc(key)
        doc[self.val_attr_name] = self.serialize(item)
        await table.put_item(Item=doc)

    async def clear(self) -> None:
        """Delete every item in this cache's namespace."""
        async for key in self.keys():
            await self.delete(key)

    async def contains(self, key: str) -> bool:
        return await self.read(key) is not None

    async def keys(self) -> AsyncIterable[str]:
        """Yield all keys in this namespace, with the namespace prefix stripped."""
        prefix_len = len(self.namespace) + 1  # +1 for the ':' separator
        async for item in self._scan():
            yield item[self.key_attr_name][prefix_len:]

    async def size(self) -> int:
        """Count items in this namespace (requires a full table scan)."""
        count = 0
        async for _ in self._scan():
            count += 1
        return count

    async def values(self) -> AsyncIterable[ResponseOrKey]:
        async for item in self._scan():
            yield self.deserialize(item[self.val_attr_name].value)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.