Skip to content

Commit

Permalink
Refactoring a little bit of the task code to wrap my head around it
Browse files Browse the repository at this point in the history
  • Loading branch information
Rich Jones committed Mar 28, 2017
1 parent e58b179 commit f1788d4
Show file tree
Hide file tree
Showing 3 changed files with 107 additions and 54 deletions.
89 changes: 62 additions & 27 deletions zappa/async.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,8 @@
import os
import json
import importlib
import inspect
import boto3


"""
Zappa Async Tasks
Example:
```
from zappa.async import task
@task(service='sns')
Expand All @@ -16,31 +12,57 @@ def my_async_func(*args, **kwargs):
res = my_async_func.async(*args, **kwargs)
if res.sent:
print('It was dispatched! Who knows what the function result will be!')
```
For sns, you can also pass an `arn` argument to task() which will specify which SNS path to send it to.
For SNS, you can also pass an `arn` argument to task() which will specify which SNS path to send it to.
Without service='sns', the default service is 'lambda' which will call the method in an asynchronous
Without `service='sns'`, the default service is 'lambda' which will call the method in an asynchronous
lambda call.
The following restrictions apply:
* func must have a clean import path -- i.e. no closures, lambdas, or methods.
* args and kwargs must be json-serializable.
* The json-serialized form must be within the size limits for Lambda (128K) or SNS (256K) events.
* function must have a clean import path -- i.e. no closures, lambdas, or methods.
* args and kwargs must be JSON-serializable.
* The JSON-serialized form must be within the size limits for Lambda (128K) or SNS (256K) events.
"""

import os
import json
import importlib
import inspect
import boto3

from util import get_topic_name

AWS_REGION = os.environ.get('AWS_REGION')
AWS_LAMBDA_FUNCTION_NAME = os.environ.get('AWS_LAMBDA_FUNCTION_NAME')

# Declare these here so they're kept warm.
LAMBDA_CLIENT = boto3.client('lambda')
SNS_CLIENT = boto3.client('sns')
STS_CLIENT = boto3.client('sts')

##
# Response and Exception classes
##

class AsyncException(Exception):
""" Simple exception class for async tasks. """
pass

class LambdaAsyncResponse(object):
"""
Base Response Dispatcher class
Can be used directly or subclassed if the method to send the message is changed.
"""
def __init__(self, **kwargs):
self.client = boto3.client('lambda')
""" """
self.client = LAMBDA_CLIENT

def send(self, task_path, args, kwargs):
"""
Create the message object and pass it to the actual sender.
"""
message = {
'task_path': task_path,
'args': args,
Expand All @@ -50,44 +72,54 @@ def send(self, task_path, args, kwargs):
return self

def _send(self, message):
"""
Given a message, directly invoke the lamdba function for this task.
"""
message['command'] = 'zappa.async.route_lambda_task'
payload = json.dumps(message).encode('utf-8')
if len(payload) > 128000:
raise AsyncException("Payload too large for async Lambda call")
self.response = self.client.invoke(
FunctionName=AWS_LAMBDA_FUNCTION_NAME,
InvocationType='Event', #makes the call async
Payload=payload)

FunctionName=AWS_LAMBDA_FUNCTION_NAME,
InvocationType='Event', #makes the call async
Payload=payload
)
self.sent = (response.get('StatusCode', 0) == 202)


class SnsAsyncResponse(LambdaAsyncResponse):
"""
Send a SNS message to a specified SNS topic
Serialise the func path and arguments
"""
def __init__(self, **kwargs):
self.client = boto3.client('sns')
self.client = SNS_CLIENT
if kwargs.get('arn'):
self.arn = kwargs.get('arn')
else:
stsclient = boto3.client('sts')
AWS_ACCOUNT_ID = stsclient.get_caller_identity()['Account']
self.arn = 'arn:aws:sns:{region}:{account}:{lambda_name}-zappa-async'.format(
region=AWS_REGION, account=AWS_ACCOUNT_ID,
lambda_name=AWS_LAMBDA_FUNCTION_NAME
)
sts_client = STS_CLIENT
AWS_ACCOUNT_ID = sts_client.get_caller_identity()['Account']
self.arn = 'arn:aws:sns:{region}:{account}:{topic_name}'.format(
region=AWS_REGION,
account=AWS_ACCOUNT_ID,
topic_name=get_topic_name(AWS_LAMBDA_FUNCTION_NAME)
)

def _send(self, message):
"""
Given a message, publish to this topic.
"""
payload = json.dumps(message)
if len(payload) > 256000:
raise AsyncException("Payload too large for SNS")
self.response = client.publish(
TargetArn=self.arn, Message=payload
)
TargetArn=self.arn,
Message=payload
)
self.sent = self.response.get('MessageId')

##
# Aync routers and utility functions
##

ASYNC_CLASSES = {
'lambda': LambdaAsyncResponse,
Expand All @@ -106,6 +138,9 @@ def _import_and_get_task(task_path):


def _get_func_task_path(func):
"""
Format the modular task path for a function via inspection.
"""
module_path = inspect.getmodule(func).__name__
task_path = '{module_path}.{func_name}'.format(
module_path=module_path,
Expand Down
8 changes: 8 additions & 0 deletions zappa/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,14 @@ def detect_flask_apps():

return matches

##
# Async Tasks
##

def get_topic_name(self, lambda_name):
""" Topic name generation """
return '%s-zappa-async' % lambda_name

##
# Event sources / Kappa
##
Expand Down
64 changes: 37 additions & 27 deletions zappa/zappa.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from tqdm import tqdm

# Zappa imports
from util import copytree, add_event_source, remove_event_source, human_size
from util import copytree, add_event_source, remove_event_source, human_size, get_topic_name

##
# Logging Config
Expand Down Expand Up @@ -1712,6 +1712,31 @@ def create_iam_roles(self):

return self.credentials_arn, updated

def _clear_policy(self, lambda_name):
"""
Remove obsolete policy statements to prevent policy from bloating over the limit after repeated updates.
"""
try:
policy_response = self.lambda_client.get_policy(
FunctionName=lambda_name
)
if policy_response['ResponseMetadata']['HTTPStatusCode'] == 200:
statement = json.loads(policy_response['Policy'])['Statement']
for s in statement:
delete_response = self.lambda_client.remove_permission(
FunctionName=lambda_name,
StatementId=s['Sid']
)
if delete_response['ResponseMetadata']['HTTPStatusCode'] != 204:
logger.error('Failed to delete an obsolete policy statement: {}'.format())
else:
logger.debug('Failed to load Lambda function policy: {}'.format(policy_response))
except ClientError as e:
if e.message.find('ResourceNotFoundException') > -1:
logger.debug('No policy found, must be first run.')
else:
logger.error('Unexpected client error {}'.format(e.message))

##
# CloudWatch Events
##
Expand Down Expand Up @@ -1942,33 +1967,15 @@ def unschedule_events(self, events, lambda_arn=None, lambda_name=None, excluded_
)
print("Removed event " + name + " (" + str(event_source['events']) + ").")

def _clear_policy(self, lambda_name):
"""
Remove obsolete policy statements to prevent policy from bloating over the limit after repeated updates.
"""
try:
policy_response = self.lambda_client.get_policy(
FunctionName=lambda_name
)
if policy_response['ResponseMetadata']['HTTPStatusCode'] == 200:
statement = json.loads(policy_response['Policy'])['Statement']
for s in statement:
delete_response = self.lambda_client.remove_permission(
FunctionName=lambda_name,
StatementId=s['Sid']
)
if delete_response['ResponseMetadata']['HTTPStatusCode'] != 204:
logger.error('Failed to delete an obsolete policy statement: {}'.format())
else:
logger.debug('Failed to load Lambda function policy: {}'.format(policy_response))
except ClientError as e:
if e.message.find('ResourceNotFoundException') > -1:
logger.debug('No policy found, must be first run.')
else:
logger.error('Unexpected client error {}'.format(e.message))
###
# Async / SNS
##

def create_async_sns_topic(self, lambda_name, lambda_arn):
topic_name = '%s-zappa-async' % lambda_name
"""
Create the SNS-based async topic.
"""
topic_name = get_topic_name(lambda_name)
# Create SNS topic
topic_arn = self.sns_client.create_topic(
Name=topic_name)['TopicArn']
Expand Down Expand Up @@ -1997,7 +2004,10 @@ def create_async_sns_topic(self, lambda_name, lambda_arn):
return topic_arn

def remove_async_sns_topic(self, lambda_name):
topic_name = '%s-zappa-async' % lambda_name
"""
Remove the async SNS topic.
"""
topic_name = get_topic_name(lambda_name)
removed_arns = []
for sub in self.sns_client.list_subscriptions()['Subscriptions']:
if topic_name in sub['TopicArn']:
Expand Down

0 comments on commit f1788d4

Please sign in to comment.