Skip to content

Commit

Permalink
Invoke downloader via SQS
Browse files Browse the repository at this point in the history
  • Loading branch information
Austin Byers committed Aug 7, 2018
1 parent 1f7a59c commit f918fdf
Show file tree
Hide file tree
Showing 7 changed files with 56 additions and 59 deletions.
19 changes: 9 additions & 10 deletions lambda_functions/analyzer/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,22 +57,21 @@ def analyze_lambda_handler(event: Dict[str, Any], lambda_context: Any) -> Dict[s
Args:
event: SQS message batch - each message body is a JSON-encoded S3 notification - {
"Records": [
'Records': [
{
"body": json.dumps({
"Records": [
"s3": {
"bucket": {
"name": "..."
'body': json.dumps({
'Records': [
's3': {
'bucket': {
'name': '...'
},
"object": {
"key": "..."
'object': {
'key': '...'
}
}
]
}),
"messageId": "...",
"receiptHandle": "..."
'messageId': '...'
}
]
}
Expand Down
54 changes: 17 additions & 37 deletions lambda_functions/downloader/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,13 @@
# ENCRYPTED_CARBON_BLACK_API_TOKEN: API token, encrypted with KMS.
# TARGET_S3_BUCKET: Name of the S3 bucket in which to save the copied binary.
import base64
import collections
import json
import logging
import os
import shutil
import subprocess
import tempfile
from typing import Any, Dict, Generator, List
from typing import Any, Dict, Generator, List, Tuple
import zipfile

import boto3
Expand All @@ -34,18 +33,14 @@
url=os.environ['CARBON_BLACK_URL'], token=DECRYPTED_TOKEN)
CLOUDWATCH = boto3.client('cloudwatch')
S3_BUCKET = boto3.resource('s3').Bucket(os.environ['TARGET_S3_BUCKET'])
SQS = boto3.resource('sqs')

# The download invocation event is parsed into a tuple with MD5 and a Receipt
DownloadRecord = collections.namedtuple('DownloadRecord', ['md5', 'sqs_receipt', 'receive_count'])


def _iter_download_records(event: Any) -> Generator[DownloadRecord, None, None]:
"""Generate DownloadRecords from the invocation event."""
for message in event['messages']:
def _iter_download_records(event: Any) -> Generator[Tuple[str, int], None, None]:
"""Yield (md5, receive_count) from the invocation event."""
for message in event['Records']:
try:
md5 = json.loads(message['body'])['md5']
yield DownloadRecord(md5, message['receipt'], message['receive_count'])
yield md5, int(message['attributes']['ApproximateReceiveCount'])
except (json.JSONDecodeError, KeyError, TypeError):
LOGGER.exception('Skipping invalid SQS record: %s', message)
continue
Expand Down Expand Up @@ -123,16 +118,6 @@ def _process_md5(md5: str) -> bool:
subprocess.check_call(['shred', '--remove', download_path])


def _delete_sqs_messages(queue_url: str, receipts: List[str], ) -> None:
    """Mark a batch of SQS receipts as completed (removing them from the queue)."""
    # NOTE(review): this helper is deleted by this commit - presumably because
    # the native SQS event source mapping makes manual deletion unnecessary;
    # confirm against the Lambda/SQS integration behavior.
    LOGGER.info('Deleting %d SQS receipt(s)', len(receipts))
    SQS.Queue(queue_url).delete_messages(
        Entries=[
            # Entry Ids only need to be unique within this batch, so the list
            # index suffices.
            {'Id': str(index), 'ReceiptHandle': receipt} for index, receipt in enumerate(receipts)
        ]
    )


def _publish_metrics(receive_counts: List[int]) -> None:
"""Send a statistic summary of receive counts."""
LOGGER.info('Sending ReceiveCount metrics')
Expand All @@ -154,30 +139,25 @@ def download_lambda_handler(event: Dict[str, Any], _: Any) -> None:
"""Lambda function entry point - copy a binary from CarbonBlack into the BinaryAlert S3 bucket.
Args:
event: SQS message batch sent by the dispatcher: {
'messages': [
event: SQS message batch - {
"Records": [
{
'body': (str) '{"md5": "FILE_MD5"}',
'receipt': (str) SQS message receipt handle,
'receive_count': (int) Approximate number of times this has been received
},
...
],
'queue_url': (str) SQS queue url from which the message originated
'attributes': {
'ApproximateReceiveCount': 1
},
'body": '{"md5": "FILE_MD5"}',
'messageId': '...'
}
]
}
_: Unused Lambda context
"""
receipts_to_delete = [] # SQS receipts which can be deleted.
receive_counts = [] # A list of message receive counts.

for record in _iter_download_records(event):
if _process_md5(record.md5):
for md5, receive_count in _iter_download_records(event):
if _process_md5(md5):
# File was copied successfully - the receipt can be deleted
receipts_to_delete.append(record.sqs_receipt)
receive_counts.append(record.receive_count)

if receipts_to_delete:
_delete_sqs_messages(event['queue_url'], receipts_to_delete)
receive_counts.append(receive_count)

if receive_counts:
_publish_metrics(receive_counts)
7 changes: 7 additions & 0 deletions terraform/lambda.tf
Original file line number Diff line number Diff line change
Expand Up @@ -85,3 +85,10 @@ module "binaryalert_downloader" {
tagged_name = "${var.tagged_name}"
alarm_sns_arns = ["${aws_sns_topic.metric_alarms.arn}"]
}

// Invoke downloader Lambda from downloader SQS queue.
resource "aws_lambda_event_source_mapping" "downloader_via_sqs" {
batch_size = "${var.download_queue_batch_size}"
event_source_arn = "${aws_sqs_queue.downloader_queue.arn}"
function_name = "${module.binaryalert_downloader.alias_arn}"
}
18 changes: 14 additions & 4 deletions terraform/lambda_iam.tf
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,10 @@ data "aws_iam_policy_document" "binaryalert_downloader_policy" {
"kms:GenerateDataKey",
]

resources = ["${aws_kms_key.sse_s3.arn}"]
resources = [
"${aws_kms_key.sse_s3.arn}",
"${aws_kms_key.sse_sqs.arn}",
]
}

statement {
Expand All @@ -163,9 +166,16 @@ data "aws_iam_policy_document" "binaryalert_downloader_policy" {
}

statement {
sid = "DeleteFromDownloadQueue"
effect = "Allow"
actions = ["sqs:DeleteMessage"]
sid = "ProcessSQSMessages"
effect = "Allow"

actions = [
"sqs:ChangeMessageVisibility",
"sqs:DeleteMessage",
"sqs:GetQueueAttributes",
"sqs:ReceiveMessage",
]

resources = ["${aws_sqs_queue.downloader_queue.arn}"]
}
}
Expand Down
3 changes: 2 additions & 1 deletion terraform/terraform.tfvars
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,9 @@ force_destroy = true


// ##### SQS #####
// Maximum number of messages that will be received by each invocation of the analyzer Lambda.
// Maximum number of messages that will be received by each invocation of the respective function.
analyze_queue_batch_size = 10
download_queue_batch_size = 1

// If an SQS message is not deleted (successfully processed) after the max number of receive
// attempts, the message is delivered to the SQS dead-letter queue.
Expand Down
1 change: 1 addition & 0 deletions terraform/variables.tf
Original file line number Diff line number Diff line change
Expand Up @@ -30,4 +30,5 @@ variable "lambda_download_timeout_sec" {}
variable "force_destroy" {}

variable "analyze_queue_batch_size" {}
variable "download_queue_batch_size" {}
variable "download_queue_max_receives" {}
13 changes: 6 additions & 7 deletions tests/lambda_functions/downloader/main_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,14 +56,14 @@ def setUp(self):

# Create the test event.
self.event = {
'messages': [
'Records': [
{
'body': '{"md5": "ABC123"}',
'receipt': 'TEST-RECEIPT',
'receive_count': 1
'attributes': {
'ApproximateReceiveCount': 1
},
'body': '{"md5": "ABC123"}'
}
],
'queue_url': 'TEST-QUEUE-URL'
]
}

# Mock out cbapi and import the file under test.
Expand Down Expand Up @@ -106,6 +106,5 @@ def test_download_from_carbon_black(self):
mock.call.info(
'Downloading %s to %s', self._binary.webui_link, mock.ANY),
mock.call.info('Uploading to S3 with key %s', mock.ANY),
mock.call.info('Deleting %d SQS receipt(s)', 1),
mock.call.info('Sending ReceiveCount metrics')
])

0 comments on commit f918fdf

Please sign in to comment.