From 7ce3998b11e5dd69cc69929ced41afbf68c16c4f Mon Sep 17 00:00:00 2001 From: Georgi Date: Wed, 27 Mar 2024 18:20:51 +0100 Subject: [PATCH] Fix a bug in s3 handler to handle SNS records --- .../steps/handlers/s3_handler.py | 11 +-- aws/logs_monitoring/tests/test_s3_handler.py | 67 +++++++++++++++++++ 2 files changed, 73 insertions(+), 5 deletions(-) diff --git a/aws/logs_monitoring/steps/handlers/s3_handler.py b/aws/logs_monitoring/steps/handlers/s3_handler.py index 512162208..eb0d10cf0 100644 --- a/aws/logs_monitoring/steps/handlers/s3_handler.py +++ b/aws/logs_monitoring/steps/handlers/s3_handler.py @@ -44,12 +44,13 @@ def s3_handler(event, context, metadata, cache_layer): # Get the S3 client s3 = get_s3_client() # if this is a S3 event carried in a SNS message, extract it and override the event - first_record = event.get("Records")[0] - if sns := first_record.get("Sns"): - event = json.loads(sns.get("Message")) + if "Sns" in event.get("Records")[0]: + event = json.loads(event.get("Records")[0].get("Sns").get("Message")) # Get the object from the event and show its content type - bucket = first_record.get("s3").get("bucket").get("name") - key = urllib.parse.unquote_plus(first_record.get("s3").get("object").get("key")) + bucket = event.get("Records")[0].get("s3").get("bucket").get("name") + key = urllib.parse.unquote_plus( + event.get("Records")[0].get("s3").get("object").get("key") + ) source = set_source(event, metadata, bucket, key) # Add Service tag add_service_tag(metadata) diff --git a/aws/logs_monitoring/tests/test_s3_handler.py b/aws/logs_monitoring/tests/test_s3_handler.py index d5ff0f487..ca9834fb1 100644 --- a/aws/logs_monitoring/tests/test_s3_handler.py +++ b/aws/logs_monitoring/tests/test_s3_handler.py @@ -1,7 +1,9 @@ import gzip import unittest +from unittest.mock import MagicMock, patch from approvaltests.combination_approvals import verify_all_combinations from steps.handlers.s3_handler import ( + s3_handler, parse_service_arn, get_partition_from_region, get_structured_lines_for_s3_handler, @@ -9,6 +11,12 @@ class TestS3EventsHandler(unittest.TestCase): + class Context: + function_version = 0 + invoked_function_arn = "invoked_function_arn" + function_name = "function_name" + memory_limit_in_mb = "10" + def parse_lines(self, data, key, source): bucket = "my-bucket" gzip_data = gzip.compress(bytes(data, "utf-8")) @@ -59,6 +67,65 @@ def test_get_partition_from_region(self): self.assertEqual(get_partition_from_region("cn-north-1"), "aws-cn") self.assertEqual(get_partition_from_region(None), "aws") + @patch("steps.handlers.s3_handler.extract_data") + @patch("steps.handlers.s3_handler.get_s3_client") + def test_s3_handler(self, mock_s3_client, extract_data): + event = { + "Records": [ + { + "s3": { + "bucket": {"name": "my-bucket"}, + "object": {"key": "my-key"}, + } + } + ] + } + context = self.Context() + metadata = {"ddtags": ""} + extract_data.side_effect = [("data".encode("utf-8"))] + cache_layer = MagicMock() + structured_lines = list(s3_handler(event, context, metadata, cache_layer)) + self.assertEqual( + structured_lines, + [ + { + "aws": {"s3": {"bucket": "my-bucket", "key": "my-key"}}, + "message": "data", + } + ], + ) + self.assertEqual(metadata["ddsource"], "s3") + self.assertEqual(metadata["host"], "arn:aws:s3:::my-bucket") + + @patch("steps.handlers.s3_handler.extract_data") + @patch("steps.handlers.s3_handler.get_s3_client") + def test_s3_handler_with_sns(self, mock_s3_client, extract_data): + event = { + "Records": [ + { + "Sns": { + "Message": '{"Records": [{"s3": {"bucket": {"name": "my-bucket"}, "object": {"key": "sns-my-key"}}}]}' + } + } + ] + } + context = self.Context() + metadata = {"ddtags": ""} + extract_data.side_effect = [("data".encode("utf-8"))] + cache_layer = MagicMock() + structured_lines = list(s3_handler(event, context, metadata, cache_layer)) + self.assertEqual( + structured_lines, + [ + { + "aws": {"s3": {"bucket": "my-bucket", "key": "sns-my-key"}}, + "message": "data", + } + ], + ) + self.assertEqual(metadata["ddsource"], "s3") + self.assertEqual(metadata["host"], "arn:aws:s3:::my-bucket") + class TestParseServiceArn(unittest.TestCase): def test_elb_s3_key_invalid(self):