Skip to content

Commit

Permalink
fix: correct behaviour to avoid caching "INPROGRESS" records (#295)
Browse files Browse the repository at this point in the history
* fix: correct behaviour to avoid caching "INPROGRESS" records

* docs: add beta flag to utility for initial release(s)

* chore: Change STATUS_CONSTANTS to MappingProxyType

* chore: Fix docstrings

* chore: readability improvements

* chore: move cache conditionals inside of cache methods

* chore: add test for unhandled types

Co-authored-by: Michael Brewer <michael.brewer@gyft.com>
  • Loading branch information
Tom McCarthy and Michael Brewer committed Feb 20, 2021
1 parent 03f7dcd commit f1a8832
Show file tree
Hide file tree
Showing 4 changed files with 123 additions and 65 deletions.
66 changes: 39 additions & 27 deletions aws_lambda_powertools/utilities/idempotency/persistence/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import json
import logging
from abc import ABC, abstractmethod
from types import MappingProxyType
from typing import Any, Dict

import jmespath
Expand All @@ -21,7 +22,7 @@

logger = logging.getLogger(__name__)

STATUS_CONSTANTS = {"INPROGRESS": "INPROGRESS", "COMPLETED": "COMPLETED", "EXPIRED": "EXPIRED"}
STATUS_CONSTANTS = MappingProxyType({"INPROGRESS": "INPROGRESS", "COMPLETED": "COMPLETED", "EXPIRED": "EXPIRED"})


class DataRecord:
Expand Down Expand Up @@ -81,8 +82,7 @@ def status(self) -> str:
"""
if self.is_expired:
return STATUS_CONSTANTS["EXPIRED"]

if self._status in STATUS_CONSTANTS.values():
elif self._status in STATUS_CONSTANTS.values():
return self._status
else:
raise IdempotencyInvalidStatusError(self._status)
Expand Down Expand Up @@ -214,14 +214,14 @@ def _validate_payload(self, lambda_event: Dict[str, Any], data_record: DataRecor
DataRecord instance
Raises
______
----------
IdempotencyValidationError
Event payload doesn't match the stored record for the given idempotency key
"""
if self.payload_validation_enabled:
lambda_payload_hash = self._get_hashed_payload(lambda_event)
if not data_record.payload_hash == lambda_payload_hash:
if data_record.payload_hash != lambda_payload_hash:
raise IdempotencyValidationError("Payload does not match stored record for this event key")

def _get_expiry_timestamp(self) -> int:
Expand All @@ -238,9 +238,30 @@ def _get_expiry_timestamp(self) -> int:
return int((now + period).timestamp())

def _save_to_cache(self, data_record: DataRecord):
"""
Save data_record to local cache except when status is "INPROGRESS"
NOTE: We can't cache "INPROGRESS" records as we have no way to reflect updates that can happen outside of the
execution environment
Parameters
----------
data_record: DataRecord
DataRecord instance
Returns
-------
"""
if not self.use_local_cache:
return
if data_record.status == STATUS_CONSTANTS["INPROGRESS"]:
return
self._cache[data_record.idempotency_key] = data_record

def _retrieve_from_cache(self, idempotency_key: str):
if not self.use_local_cache:
return
cached_record = self._cache.get(idempotency_key)
if cached_record:
if not cached_record.is_expired:
Expand All @@ -249,11 +270,13 @@ def _retrieve_from_cache(self, idempotency_key: str):
self._delete_from_cache(idempotency_key)

def _delete_from_cache(self, idempotency_key: str):
if not self.use_local_cache:
return
del self._cache[idempotency_key]

def save_success(self, event: Dict[str, Any], result: dict) -> None:
"""
Save record of function's execution completing succesfully
Save record of function's execution completing successfully
Parameters
----------
Expand All @@ -277,8 +300,7 @@ def save_success(self, event: Dict[str, Any], result: dict) -> None:
)
self._update_record(data_record=data_record)

if self.use_local_cache:
self._save_to_cache(data_record)
self._save_to_cache(data_record)

def save_inprogress(self, event: Dict[str, Any]) -> None:
"""
Expand All @@ -298,18 +320,11 @@ def save_inprogress(self, event: Dict[str, Any]) -> None:

logger.debug(f"Saving in progress record for idempotency key: {data_record.idempotency_key}")

if self.use_local_cache:
cached_record = self._retrieve_from_cache(idempotency_key=data_record.idempotency_key)
if cached_record:
raise IdempotencyItemAlreadyExistsError
if self._retrieve_from_cache(idempotency_key=data_record.idempotency_key):
raise IdempotencyItemAlreadyExistsError

self._put_record(data_record)

# This has to come after _put_record. If _put_record call raises ItemAlreadyExists we shouldn't populate the
# cache with an "INPROGRESS" record as we don't know the status in the data store at this point.
if self.use_local_cache:
self._save_to_cache(data_record)

def delete_record(self, event: Dict[str, Any], exception: Exception):
"""
Delete record from the persistence store
Expand All @@ -329,8 +344,7 @@ def delete_record(self, event: Dict[str, Any], exception: Exception):
)
self._delete_record(data_record)

if self.use_local_cache:
self._delete_from_cache(data_record.idempotency_key)
self._delete_from_cache(data_record.idempotency_key)

def get_record(self, event: Dict[str, Any]) -> DataRecord:
"""
Expand All @@ -356,17 +370,15 @@ def get_record(self, event: Dict[str, Any]) -> DataRecord:

idempotency_key = self._get_hashed_idempotency_key(event)

if self.use_local_cache:
cached_record = self._retrieve_from_cache(idempotency_key=idempotency_key)
if cached_record:
logger.debug(f"Idempotency record found in cache with idempotency key: {idempotency_key}")
self._validate_payload(event, cached_record)
return cached_record
cached_record = self._retrieve_from_cache(idempotency_key=idempotency_key)
if cached_record:
logger.debug(f"Idempotency record found in cache with idempotency key: {idempotency_key}")
self._validate_payload(event, cached_record)
return cached_record

record = self._get_record(idempotency_key)

if self.use_local_cache:
self._save_to_cache(data_record=record)
self._save_to_cache(data_record=record)

self._validate_payload(event, record)
return record
Expand Down
57 changes: 30 additions & 27 deletions docs/utilities/idempotency.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,11 @@ title: Idempotency
description: Utility
---

This utility provides a simple solution to convert your Lambda functions into idempotent operations which are safe to
retry.
!!! attention
**This utility is currently in beta**. Please open an [issue in GitHub](https://github.com/awslabs/aws-lambda-powertools-python/issues/new/choose) for any bugs or feature requests.

The idempotency utility provides a simple solution to convert your Lambda functions into idempotent operations which
are safe to retry.

## Terminology

Expand All @@ -31,31 +34,31 @@ storage layer, so you'll need to create a table first.
> Example using AWS Serverless Application Model (SAM)
=== "template.yml"
```yaml
Resources:
HelloWorldFunction:
Type: AWS::Serverless::Function
Properties:
Runtime: python3.8
...
Policies:
- DynamoDBCrudPolicy:
TableName: !Ref IdempotencyTable
IdempotencyTable:
Type: AWS::DynamoDB::Table
Properties:
AttributeDefinitions:
- AttributeName: id
AttributeType: S
BillingMode: PAY_PER_REQUEST
KeySchema:
- AttributeName: id
KeyType: HASH
TableName: "IdempotencyTable"
TimeToLiveSpecification:
AttributeName: expiration
Enabled: true
```
```yaml
Resources:
HelloWorldFunction:
Type: AWS::Serverless::Function
Properties:
Runtime: python3.8
...
Policies:
- DynamoDBCrudPolicy:
TableName: !Ref IdempotencyTable
IdempotencyTable:
Type: AWS::DynamoDB::Table
Properties:
AttributeDefinitions:
- AttributeName: id
AttributeType: S
BillingMode: PAY_PER_REQUEST
KeySchema:
- AttributeName: id
KeyType: HASH
TableName: "IdempotencyTable"
TimeToLiveSpecification:
AttributeName: expiration
Enabled: true
```

!!! warning
When using this utility with DynamoDB, your lambda responses must always be smaller than 400kb. Larger items cannot
Expand Down
55 changes: 44 additions & 11 deletions tests/functional/idempotency/test_idempotency.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,12 @@ def test_idempotent_lambda_in_progress_with_cache(

stubber.add_client_error("put_item", "ConditionalCheckFailedException")
stubber.add_response("get_item", ddb_response, expected_params)

stubber.add_client_error("put_item", "ConditionalCheckFailedException")
stubber.add_response("get_item", copy.deepcopy(ddb_response), copy.deepcopy(expected_params))

stubber.add_client_error("put_item", "ConditionalCheckFailedException")
stubber.add_response("get_item", copy.deepcopy(ddb_response), copy.deepcopy(expected_params))
stubber.activate()

@idempotent(persistence_store=persistence_store)
Expand All @@ -151,11 +157,8 @@ def lambda_handler(event, context):
assert retrieve_from_cache_spy.call_count == 2 * loops
retrieve_from_cache_spy.assert_called_with(idempotency_key=hashed_idempotency_key)

assert save_to_cache_spy.call_count == 1
first_call_args_data_record = save_to_cache_spy.call_args_list[0].kwargs["data_record"]
assert first_call_args_data_record.idempotency_key == hashed_idempotency_key
assert first_call_args_data_record.status == "INPROGRESS"
assert persistence_store._cache.get(hashed_idempotency_key)
save_to_cache_spy.assert_called()
assert persistence_store._cache.get(hashed_idempotency_key) is None

stubber.assert_no_pending_responses()
stubber.deactivate()
Expand Down Expand Up @@ -223,12 +226,10 @@ def lambda_handler(event, context):

lambda_handler(lambda_apigw_event, {})

assert retrieve_from_cache_spy.call_count == 1
assert save_to_cache_spy.call_count == 2
first_call_args, second_call_args = save_to_cache_spy.call_args_list
assert first_call_args.args[0].status == "INPROGRESS"
assert second_call_args.args[0].status == "COMPLETED"
assert persistence_store._cache.get(hashed_idempotency_key)
retrieve_from_cache_spy.assert_called_once()
save_to_cache_spy.assert_called_once()
assert save_to_cache_spy.call_args[0][0].status == "COMPLETED"
assert persistence_store._cache.get(hashed_idempotency_key).status == "COMPLETED"

# This lambda call should not call AWS API
lambda_handler(lambda_apigw_event, {})
Expand Down Expand Up @@ -594,3 +595,35 @@ def test_data_record_invalid_status_value():
_ = data_record.status

assert e.value.args[0] == "UNSUPPORTED_STATUS"


@pytest.mark.parametrize("persistence_store", [{"use_local_cache": True}], indirect=True)
def test_in_progress_never_saved_to_cache(persistence_store):
    """A record with status "INPROGRESS" must never reach the local cache."""
    # GIVEN local caching is enabled and a record that is still in progress
    in_progress_record = DataRecord("key", status="INPROGRESS")

    # WHEN attempting to store it in the local cache
    persistence_store._save_to_cache(in_progress_record)

    # THEN nothing is cached under that key
    assert persistence_store._cache.get("key") is None


@pytest.mark.parametrize("persistence_store", [{"use_local_cache": False}], indirect=True)
def test_user_local_disabled(persistence_store):
    """Cache helpers must be safe no-ops when use_local_cache is False."""
    # GIVEN a persistence_store with use_local_cache = False

    # WHEN calling any local cache operation
    data_record = DataRecord("key", status="COMPLETED")
    try:
        persistence_store._save_to_cache(data_record)
        cache_value = persistence_store._retrieve_from_cache("key")
        assert cache_value is None
        persistence_store._delete_from_cache("key")
    except AttributeError as e:
        pytest.fail(f"AttributeError should not be raised: {e}")

    # THEN no AttributeError was raised
    # AND the store never created a _cache attribute
    # Bug fix: hasattr was previously called on the string literal
    # "persistence_store", making the assertion vacuously true.
    assert not hasattr(persistence_store, "_cache")
10 changes: 10 additions & 0 deletions tests/unit/test_json_encoder.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import decimal
import json

import pytest

from aws_lambda_powertools.shared.json_encoder import Encoder


Expand All @@ -12,3 +14,11 @@ def test_jsonencode_decimal():
def test_jsonencode_decimal_nan():
result = json.dumps({"val": decimal.Decimal("NaN")}, cls=Encoder)
assert result == '{"val": NaN}'


def test_jsonencode_calls_default():
    """Encoder defers unknown types to json's default handling (TypeError)."""

    class _Unserializable:
        pass

    # WHEN serializing a type Encoder has no special handling for
    # THEN the standard json TypeError propagates
    with pytest.raises(TypeError):
        json.dumps({"val": _Unserializable()}, cls=Encoder)

0 comments on commit f1a8832

Please sign in to comment.