From 44fe7d0f8a5b28e3a0403ad1ab90f70186ea5b9e Mon Sep 17 00:00:00 2001 From: Ryxias Date: Wed, 30 Jan 2019 12:47:11 -0800 Subject: [PATCH] Separates CredentialProvider from OutputDispatcher (#875) * Initial commit * Now with tests * Touchups * Remove extraneous static method * DRYs out some code related to S3 secret bucket name * Proof-of-concept of mocking CredentialProvider * Fix some pylints * Removes output_cred_name() to consolidate functionality. Updates all tests * pylint * Refactors Drivers for OutputCredentialsProvider (#878) * Working? Draft of new driver-based credentials storage * Higher quality refactor; needs tests * Kinks ironed out with good tests this time * First maybe working try * Adds lots of tests for the Drivers * Add more tests. Not final; still need to remove deprecated methods * Rename method to reduce confusion * Remove deprecated method load_encrypted_credentials_from_s3 * Removes deprecated method get_local_credentials_temp_dir * Remove deprecated method get_formatted_output_credentials_name * Removes deprecated method kms_decrypt * Removes extraneous imports * Extract globally injected REGION so the handler can be implemented properly * Pylint` * Pylint is my nemesis * Remove extraneous method * Use default_config for boto clients * Code coverage. PR feedback. * Add missing __init__.py file causing poor code coverage * Fixes the tests to get past pylint garbage * Change account_id to a property * Formatting * (fixup) PR Feedback * (fixup) Disentangles streamalert cli helper with stream_alert credentials provider. * (fixup) Deprecate old methods * pylint * Removes deprecated code; adds tests for aws_api_client; DRYs out test code --- stream_alert/alert_processor/outputs/aws.py | 4 + .../outputs/credentials/__init__.py | 0 .../outputs/credentials/provider.py | 685 ++++++++++++++++ .../alert_processor/outputs/output_base.py | 130 +-- stream_alert/shared/helpers/aws_api_client.py | 184 +++++ stream_alert_cli/outputs/handler.py | 16 +- stream_alert_cli/outputs/helpers.py | 71 -- stream_alert_cli/terraform/generate.py | 1 + tests/unit/helpers/aws_mocks.py | 9 +- .../stream_alert_alert_processor/helpers.py | 18 +- .../test_outputs/credentials/__init__.py | 0 .../test_outputs/credentials/test_provider.py | 742 ++++++++++++++++++ .../test_outputs/test_carbonblack.py | 38 +- .../test_outputs/test_github.py | 47 +- .../test_outputs/test_jira.py | 36 +- .../test_outputs/test_komand.py | 36 +- .../test_outputs/test_output_base.py | 64 +- .../test_outputs/test_pagerduty.py | 81 +- .../test_outputs/test_phantom.py | 40 +- .../test_outputs/test_slack.py | 34 +- tests/unit/stream_alert_cli/test_outputs.py | 82 -- .../test_aws_api_client.py | 84 ++ 22 files changed, 1842 insertions(+), 560 deletions(-) create mode 100644 stream_alert/alert_processor/outputs/credentials/__init__.py create mode 100644 stream_alert/alert_processor/outputs/credentials/provider.py create mode 100644 stream_alert/shared/helpers/aws_api_client.py create mode 100644 tests/unit/stream_alert_alert_processor/test_outputs/credentials/__init__.py create mode 100644 tests/unit/stream_alert_alert_processor/test_outputs/credentials/test_provider.py delete mode 100644 tests/unit/stream_alert_cli/test_outputs.py create mode 100644 tests/unit/stream_alert_shared/test_aws_api_client.py diff --git a/stream_alert/alert_processor/outputs/aws.py b/stream_alert/alert_processor/outputs/aws.py index 13602a732..37ccbc863 100644 --- a/stream_alert/alert_processor/outputs/aws.py +++ b/stream_alert/alert_processor/outputs/aws.py @@ -338,6 +338,10 @@ def _dispatch(self, alert, descriptor): return True + @property + def account_id(self): + return self._credentials_provider.get_aws_account_id() + @StreamAlertOutput class SQSOutput(AWSOutput): diff --git a/stream_alert/alert_processor/outputs/credentials/__init__.py b/stream_alert/alert_processor/outputs/credentials/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/stream_alert/alert_processor/outputs/credentials/provider.py b/stream_alert/alert_processor/outputs/credentials/provider.py new file mode 100644 index 000000000..ec2203452 --- /dev/null +++ b/stream_alert/alert_processor/outputs/credentials/provider.py @@ -0,0 +1,685 @@ +""" +Copyright 2017-present, Airbnb Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" +import json +import os +import shutil +import tempfile +from abc import abstractmethod + +from botocore.exceptions import ClientError + +from stream_alert.shared.helpers.aws_api_client import AwsKms, AwsS3 +from stream_alert.shared.logger import get_logger + +LOGGER = get_logger(__name__) + + +class OutputCredentialsProvider(object): + """Loads credentials that are housed on AWS S3, or cached locally. + + Helper service to OutputDispatcher. + + OutputDispatcher implementations may require credentials to authenticate with an external + gateway. All credentials for OutputDispatchers are to be stored in a single bucket on AWS S3 + and are encrypted with AWS KMS. When alerts are dispatched via OutputDispatchers, these + encrypted credentials are downloaded and cached locally on the filesystem. Then, AWS KMS is + used to decrypt the credentials when in use. + + Public methods: + load_credentials: Returns a dict of the credentials requested + get_local_credentials_temp_dir(): Returns full path to a temporary directory where all + encrypted credentials are cached. + """ + + def __init__(self, + service_name, + config=None, + defaults=None, + region=None, + prefix=None, + aws_account_id=None): + self._service_name = service_name + + # Region: Check constructor args first, then config + self._region = config['global']['account']['region'] if region is None else region + + # Prefix: Check constructor args first, then ENV, then config + self._prefix = self._calculate_prefix(prefix, config) + + # Account Id: Check constructor args first, then ENV, then config + self._account_id = self._calculate_account_id(aws_account_id, config) + + self._defaults = defaults if defaults else {} + + # Drivers are strategies utilized by this class for fetching credentials from various + # locations on disk or remotely + self._drivers = [] # type: list[CredentialsProvidingDriver] + self._core_driver = None # type: S3Driver + self._setup_drivers() + + @staticmethod + def _calculate_prefix(given_prefix, config): + if given_prefix is not None: + return given_prefix + + if 'STREAMALERT_PREFIX' in os.environ: + return os.environ['STREAMALERT_PREFIX'] + + return config['global']['account']['prefix'] + + @staticmethod + def _calculate_account_id(given_account_id, config): + if given_account_id is not None: + return given_account_id + + if 'AWS_ACCOUNT_ID' in os.environ: + return os.environ['AWS_ACCOUNT_ID'] + + return config['global']['account']['aws_account_id'] + + def _setup_drivers(self): + """Initializes all drivers. + + The Drivers are sequentially checked in the order they are appended to the driver list. + """ + + # Ephemeral driver + ep_driver = EphemeralUnencryptedDriver(self._service_name) + self._drivers.append(ep_driver) + + # Fall back onto downloading encrypted credentials from S3 + s3_driver = S3Driver(self._prefix, self._service_name, self._region, cache_driver=ep_driver) + self._core_driver = s3_driver + self._drivers.append(s3_driver) + + def save_credentials(self, descriptor, kms_key_alias, props): + """Saves given credentials into S3. + + Args: + descriptor (str): OutputDispatcher descriptor + kms_key_alias (str): KMS Key alias provided by configs + props (Dict(str, OutputProperty)): A dict containing strings mapped to OutputProperty + objects. + + Returns: + bool: True is credentials successfully saved. False otherwise. + """ + + creds = {name: prop.value + for (name, prop) in props.iteritems() if prop.cred_requirement} + + credentials = Credentials(creds, False, self._region) + return self._core_driver.save_credentials_into_s3(descriptor, credentials, kms_key_alias) + + def load_credentials(self, descriptor): + """Loads credentials from the drivers. + + Args: + descriptor (str): unique identifier used to look up these credentials + + Returns: + dict: the loaded credential info needed for sending alerts to this service + or None if nothing gets loaded + """ + credentials = None + for driver in self._drivers: + if driver.has_credentials(descriptor): + credentials = driver.load_credentials(descriptor) + if credentials: + break + + if not credentials: + LOGGER.error( + 'All drivers failed to retrieve credentials for [%s.%s]', + self._service_name, + descriptor + ) + return None + elif credentials.is_encrypted(): + decrypted_creds = credentials.get_data_kms_decrypted() + else: + decrypted_creds = credentials.data() + + creds_dict = json.loads(decrypted_creds) + + # Add any of the hard-coded default output props to this dict (ie: url) + defaults = self._defaults + if defaults: + creds_dict.update(defaults) + + return creds_dict + + def get_aws_account_id(self): + """Returns the AWS account ID""" + return self._account_id + + +class Credentials(object): + """Encapsulation for a set of credentials. + + When storing to or loading from a Driver, the raw credentials data may or may not be encrypted + (e.g. when writing to disk, we should always keep it encrypted). To allow Credentials to be + passed from Driver to Driver without excessive calls to KMS.encrypt/decrypt, this "data" in + the Credentials can be either encrypted or not. It is up to the code that constructs the + Credentials object to know which. + + When retrieving the raw, unencrypted data from a Credentials object, use the following code: + + if credentials.is_encrypted(): + return credentials.get_data_kms_decrypted() + else: + return credentials.data() + """ + + def __init__(self, data, is_encrypted=False, region=None): + """ + Args: + data (object|string): A json serializable object, or a string. + is_encrypted (bool): Pass True if the input data is encrypted with KMS. False otherwise. + region (str): AWS Region. Only required if is_encrypted=True. + """ + self._data = data + self._is_encrypted = is_encrypted + self._region = region if is_encrypted else None # No use for region if unencrypted + + def is_encrypted(self): + """True if this Credentials object is encrypted. False otherwise. + + Returns: + bool + """ + return self._is_encrypted + + def data(self): + """ + Returns: + str: The raw text data of this Credentials object, encrypted or not. This may be + unusable if encrypted, but can be passed to another Driver for storage. + """ + return self._data + + def get_data_kms_decrypted(self): + """Returns the data of this Credentials objects, decrypted with KMS. + + This does not mutate the internals of this Credentials for safety. It simply returns + the decrypted payload. It is up to the called to safely manage the payload. + + Returns: + str|None: The decrypted payload of this Credentials object, if it is encrypted. If it is + not encrypted, then will return None and log an error. + """ + if not self._is_encrypted: + LOGGER.error('Cannot decrypt Credentials as they are already decrypted') + return None + + try: + return AwsKms.decrypt(self._data, region=self._region) + except ClientError: + LOGGER.exception('an error occurred during credentials decryption') + return None + + def encrypt(self, region, kms_key_alias): + """Encrypts the current Credentials. + + Calling this method will entirely change the internals of this Credentials object. + Subsequent calls to .data() will return encrypted data. + """ + if self.is_encrypted(): + return + + self._is_encrypted = True + if not self._data: + return + + creds_json = json.dumps(self._data, separators=(',', ':')) + self._region = region + self._data = AwsKms.encrypt(creds_json, region=self._region, key_alias=kms_key_alias) + + +class CredentialsProvidingDriver(object): + """Drivers encapsulate logic for loading credentials""" + + @abstractmethod + def load_credentials(self, descriptor): + """Loads the requested credentials into a new Credentials object. + + The behavior can be nondeterministic if has_credentials() is false. + + Args: + descriptor (string): Descriptor for the current output service + + Return: + Credentials|None: Returns None when loading fails. + """ + + @abstractmethod + def has_credentials(self, descriptor): + """Determines whether the current driver is capable of loading the requested credentials. + + Args: + descriptor (string): Descriptor for the current output service + + Return: + bool: True if this driver has the requested Credentials, False otherwise. + """ + + +class FileDescriptorProvider(object): + """Interface for Drivers capable of offering file-handles to aid in download of credentials.""" + + @abstractmethod + def offer_fileobj(self, descriptor): + """Offers a file-like object. + + The caller is expected to call this method in a with block, and this file-like object + is expected to automatically close. + + Returns: + file object + """ + + +class CredentialsCachingDriver(object): + """Interface for Drivers capable of being used as a caching layer to accelerate the speed + of credential loading.""" + + @abstractmethod + def save_credentials(self, descriptor, credentials): + """Saves the given credentials. + + On a subsequent call of load_credentials(), the same credentials will be loaded. + + Args: + descriptor (str): OutputDispatcher descriptor + credentials (Credentials): The credentials object to save. Notably, certain drivers are + incapable of (or disallowed from) saving credentials that are unencrypted. + + Return: + bool: True if saving succeeds. False otherwise. + """ + + +def get_formatted_output_credentials_name(service_name, descriptor): + """Gives a unique name for credentials for the given service + descriptor. + + Args: + service_name (str): Service name on output class (i.e. "pagerduty", "demisto") + descriptor (str): Service destination (ie: slack channel, pd integration) + + Returns: + str: Formatted credential name (ie: slack/ryandchannel) + """ + cred_name = str(service_name) + + # should descriptor be enforced in all rules? + if descriptor: + cred_name = '{}/{}'.format(cred_name, descriptor) + + return cred_name + + +class S3Driver(CredentialsProvidingDriver): + """Driver for fetching credentials from AWS S3""" + + def __init__(self, prefix, service_name, region, file_driver=None, cache_driver=None): + """ + Args: + prefix (str): StreamAlert account prefix in configs + service_name (str): The service name for the OutputDispatcher using this + region (str): AWS Region + file_driver (FileDescriptorProvider|None): + Optional. When provided, the file_driver will be used to provide a File handle + for downloading the S3 credentials into. This can be useful if it is desired to + download the S3 credentials into a specific file for examination. + + If omitted, will defaulted to using SpooledTempfileDriver, which downloads + the S3 file into memory temporarily, and is cleaned up afterward. + + In all cases, the credentials file is downloaded and stored in the file-like + handle in ENCRYPTED FORM. + + cache_driver (CredentialsProvidingDriver|None): + Optional. When provided, the downloaded credentials will be cached in the given + driver. This is useful for reducing the number of S3/KMS calls and speeding up the + system. + + (!) Storage encryption of the credentials is determined by the driver. + """ + self._service_name = service_name + self._region = region + self._prefix = prefix + self._bucket = self.get_s3_secrets_bucket() + + self._file_driver = file_driver # type: FileDescriptorProvider + if not self._file_driver: + self._file_driver = SpooledTempfileDriver(self._service_name, self._region) + + self._cache_driver = cache_driver # type: CredentialsCachingDriver + + def load_credentials(self, descriptor): + """Loads credentials from AWS S3. + + Args: + descriptor (str): Service destination (ie: slack channel, pd integration) + + Returns: + Credentials: The loaded Credentials. None on failure + """ + try: + with self._file_driver.offer_fileobj(descriptor) as file_handle: + enc_creds = AwsS3.download_fileobj( + file_handle, + bucket=self._bucket, + region=self._region, + key=self.get_s3_key(descriptor) + ) + + credentials = Credentials(enc_creds, True, self._region) + if self._cache_driver: + self._cache_driver.save_credentials(descriptor, credentials) + + return credentials + except ClientError: + LOGGER.exception('credentials for \'%s\' could not be downloaded from S3', + get_formatted_output_credentials_name(self._service_name, descriptor)) + return None + + def has_credentials(self, descriptor): + """Always returns True, as S3 is the place where all encrypted credentials are + guaranteed to be cold-stored.""" + return True + + def save_credentials_into_s3(self, descriptor, credentials, kms_key_alias): + """Takes the given credentials, encrypts them, and saves them to AWS S3. + + Notably, this implementation is NOT for the CredentialsCachingDriver interface, as the + S3Driver is not a caching driver. + + Args: + descriptor (str): Descriptor of the current Output + credentials (Credentials): Credentials object to be saved into S3 + kms_key_alias (str): KMS key alias for streamalert secrets + + Returns: + bool: True on success, False otherwise. + """ + s3_key = get_formatted_output_credentials_name(self._service_name, descriptor) + + # Encrypt the creds and push them to S3 + if not credentials.is_encrypted(): + credentials.encrypt(self._region, kms_key_alias) + + encrypted_credentials = credentials.data() + if not encrypted_credentials: + return True + + try: + return AwsS3.put_object( + encrypted_credentials, + bucket=self._bucket, + key=s3_key, + region=self._region + ) + except ClientError: + LOGGER.exception( + 'An error occurred while sending credentials to S3 for key \'%s\' in bucket \'%s\'', + s3_key, + self._bucket + ) + return False + + def get_s3_key(self, descriptor): + """Returns an appropriate S3 bucket key for credentials relevant to this Output. + + Args: + descriptor (str): Descriptor of the current Output + + Returns: + string + """ + return get_formatted_output_credentials_name(self._service_name, descriptor) + + def get_s3_secrets_bucket(self): + """Returns an appropriate S3 bucket for all credentials relevant to this driver. + + Returns: + string + """ + return '{}.streamalert.secrets'.format(self._prefix) + + +class LocalFileDriver(CredentialsProvidingDriver, FileDescriptorProvider, CredentialsCachingDriver): + """Driver for fetching credentials that are saved locally on the filesystem.""" + + def __init__(self, region, service_name): + self._region = region + self._service_name = service_name + self._temp_dir = self.get_local_credentials_temp_dir() + + def load_credentials(self, descriptor): + local_cred_location = self.get_file_path(descriptor) + with open(local_cred_location, 'rb') as cred_file: + encrypted_credentials = cred_file.read() + + return Credentials(encrypted_credentials, True, self._region) + + def has_credentials(self, descriptor): + return os.path.exists(self.get_file_path(descriptor)) + + def save_credentials(self, descriptor, credentials): + if not credentials.is_encrypted(): + LOGGER.error('Error: Writing unencrypted credentials to disk is disallowed.') + return False + + with self.offer_fileobj(descriptor) as file_handle: + file_handle.write(credentials.data()) + return True + + @staticmethod + def clear(): + """Removes the local secrets directory that may be left from previous runs""" + secrets_dirtemp_dir = LocalFileDriver.get_local_credentials_temp_dir() + + # Check if the folder exists, and remove it if it does + if os.path.isdir(secrets_dirtemp_dir): + shutil.rmtree(secrets_dirtemp_dir) + + def offer_fileobj(self, descriptor): + """Opens a file-like object and returns it. + + If you use the return value in a `with` statement block then the file descriptor + will auto-close. + + Args: + descriptor (str): Descriptor of the current Output + + Return: + file object + """ + file_path = self.get_file_path(descriptor) + if not os.path.exists(file_path): + os.makedirs(os.path.dirname(file_path)) + + return open(file_path, 'a+b') # read+write and in binary mode + + def get_file_path(self, descriptor): + local_cred_location = os.path.join( + self._temp_dir, + get_formatted_output_credentials_name(self._service_name, descriptor) + ) + return local_cred_location + + @staticmethod + def get_local_credentials_temp_dir(): + """Returns a temporary directory on the filesystem to store encrypted credentials. + + Will automatically create the new directory if it does not exist. + + Returns: + str: local path for stream_alert_secrets tmp directory + """ + temp_dir = os.path.join(tempfile.gettempdir(), "stream_alert_secrets") + + # Check if this item exists as a file, and remove it if it does + if os.path.isfile(temp_dir): + os.remove(temp_dir) + + # Create the folder on disk to store the credentials temporarily + if not os.path.exists(temp_dir): + os.makedirs(temp_dir) + + return temp_dir + + +class SpooledTempfileDriver(CredentialsProvidingDriver, FileDescriptorProvider): + """Driver for fetching credentials that are stored in memory in file-like objects.""" + + SERVICE_SPOOLS = {} + + def __init__(self, service_name, region): + self._service_name = service_name + self._region = region + + def has_credentials(self, descriptor): + key = self.get_spool_cache_key(descriptor) + return key in type(self).SERVICE_SPOOLS + + def load_credentials(self, descriptor): + """Loads the credentials from a temporary spool.""" + key = self.get_spool_cache_key(descriptor) + if key not in type(self).SERVICE_SPOOLS: + LOGGER.error( + 'SpooledTempfileDriver failed to load_credentials: Spool "%s" does not exist?', + key + ) + return None + + spool = type(self).SERVICE_SPOOLS[key] + + spool.seek(0) + raw_data = spool.read() + + return Credentials(raw_data, True, self._region) + + def save_credentials(self, descriptor, credentials): + """Saves the credentials into a temporary spool. + + Args: + descriptor (str): Descriptor of the current Output + credentials (Credentials): Credentials object that is intended to be saved + + Return: + bool: True on success, False otherwise + """ + # Always store unencrypted because it's in memory. Saves calls to KMS and it's safe + # because other unrelated processes cannot read this memory (probably..) + if not credentials.is_encrypted(): + LOGGER.error('Error: Writing unencrypted credentials to disk is disallowed.') + return False + + raw_creds = credentials.data() + + spool = tempfile.SpooledTemporaryFile() + spool.write(raw_creds) + + key = self.get_spool_cache_key(descriptor) + type(self).SERVICE_SPOOLS[key] = spool + + return True + + @classmethod + def clear(cls): + """Clears all global spools. + + De-allocating the spools triggers garbage collection, which implicitly closes the + file handles. + """ + cls.SERVICE_SPOOLS = {} + + def offer_fileobj(self, descriptor): + """Opens a file-like temporary file spool and returns it. + + If you use the return value in a `with` statement block then the file descriptor + auto-close. + + NOTE: (!) This returns an ephemeral spool that is not attached to the caching mechanism + in save_credentials() and load_credentials() + + Args: + descriptor (str): Descriptor of the current Output + + Returns: + file object + """ + return tempfile.SpooledTemporaryFile(0, 'a+b') + + def get_spool_cache_key(self, descriptor): + return '{}/{}'.format(self._service_name, descriptor) + + +class EphemeralUnencryptedDriver(CredentialsProvidingDriver, CredentialsCachingDriver): + """Stores credentials UNENCRYPTED on the Python runtime stack. + + It is ephemeral and is only readable by the current Python process... hopefully. + """ + + CREDENTIALS_STORE = {} + + def __init__(self, service_name): + self._service_name = service_name + + def has_credentials(self, descriptor): + key = self.get_storage_key(descriptor) + return key in type(self).CREDENTIALS_STORE + + def load_credentials(self, descriptor): + key = self.get_storage_key(descriptor) + if key not in type(self).CREDENTIALS_STORE: + LOGGER.error( + 'EphemeralUnencryptedDriver failed to load_credentials: Key "%s" does not exist?', + key + ) + return None + + unencrypted_raw_creds = type(self).CREDENTIALS_STORE[key] + + return Credentials(unencrypted_raw_creds, False) + + def save_credentials(self, descriptor, credentials): + """Saves the credentials into static python memory. + + Args: + descriptor (str): Descriptor of the current Output + credentials (Credentials): Credentials object that is intended to be saved + + Return: + bool: True on success, False otherwise + """ + if credentials.is_encrypted(): + unencrypted_raw_creds = credentials.get_data_kms_decrypted() + else: + unencrypted_raw_creds = credentials.data() + + key = self.get_storage_key(descriptor) + type(self).CREDENTIALS_STORE[key] = unencrypted_raw_creds + return True + + @classmethod + def clear(cls): + cls.CREDENTIALS_STORE.clear() + + def get_storage_key(self, descriptor): + return '{}/{}'.format(self._service_name, descriptor) diff --git a/stream_alert/alert_processor/outputs/output_base.py b/stream_alert/alert_processor/outputs/output_base.py index 5ba43ee4e..687e04157 100644 --- a/stream_alert/alert_processor/outputs/output_base.py +++ b/stream_alert/alert_processor/outputs/output_base.py @@ -15,17 +15,13 @@ """ from abc import ABCMeta, abstractmethod from collections import namedtuple -import json -import os -import tempfile import requests from requests.exceptions import Timeout as ReqTimeout import urllib3 import backoff -import boto3 -from botocore.exceptions import ClientError +from stream_alert.alert_processor.outputs.credentials.provider import OutputCredentialsProvider from stream_alert.shared.backoff_handlers import ( backoff_handler, success_handler, @@ -120,13 +116,6 @@ class OutputDispatcher(object): """OutputDispatcher is the base class to handle routing alerts to outputs Public methods: - get_secrets_bucket_name: returns the name of the s3 bucket for secrets that - includes a unique prefix - output_cred_name: the name that is used to store the credentials both on s3 - and locally on disk in tmp - get_config_service: the name of the service used by the config to store any - configured outputs for this service. implemented by some subclasses, but - subclass is not required to implement format_output_config: returns a formatted version of the outputs configuration that is to be written to disk get_user_defined_properties: returns any properties for this output that must be @@ -145,33 +134,18 @@ class OutputDispatcher(object): _DEFAULT_REQUEST_TIMEOUT = 3.05 def __init__(self, config): - self.account_id = os.environ['AWS_ACCOUNT_ID'] self.region = REGION - self.secrets_bucket = '{}.streamalert.secrets'.format(os.environ['STREAMALERT_PREFIX']) self.config = config - @staticmethod - def _local_temp_dir(): - """Get the local tmp directory for caching the encrypted service credentials - - Returns: - str: local path for stream_alert_secrets tmp directory - """ - temp_dir = os.path.join(tempfile.gettempdir(), "stream_alert_secrets") - - # Check if this item exists as a file, and remove it if it does - if os.path.isfile(temp_dir): - os.remove(temp_dir) - - # Create the folder on disk to store the credentials temporarily - if not os.path.exists(temp_dir): - os.makedirs(temp_dir) - - return temp_dir + self._credentials_provider = OutputCredentialsProvider( + self.__service__, + config=config, + defaults=self._get_default_properties(), + region=self.region + ) def _load_creds(self, descriptor): - """First try to load the credentials from /tmp and then resort to pulling - the credentials from S3 if they are not cached locally + """Loads a dict of credentials relevant to this output descriptor Args: descriptor (str): unique identifier used to look up these credentials @@ -180,74 +154,7 @@ def _load_creds(self, descriptor): dict: the loaded credential info needed for sending alerts to this service or None if nothing gets loaded """ - local_cred_location = os.path.join(self._local_temp_dir(), - self.output_cred_name(descriptor)) - - # Creds are not cached locally, so get the encrypted blob from s3 - if not os.path.exists(local_cred_location): - if not self._get_creds_from_s3(local_cred_location, descriptor): - return - - # Open encrypted credential file - with open(local_cred_location, 'rb') as cred_file: - enc_creds = cred_file.read() - - # Get the decrypted credential json from kms and load into dict - # This could be None if the kms decryption fails, so check it - decrypted_creds = self._kms_decrypt(enc_creds) - if not decrypted_creds: - return - - creds_dict = json.loads(decrypted_creds) - - # Add any of the hard-coded default output props to this dict (ie: url) - defaults = self._get_default_properties() - if defaults: - creds_dict.update(defaults) - - return creds_dict - - def _get_creds_from_s3(self, cred_location, descriptor): - """Pull the encrypted credential blob for this service and destination from s3 - - Args: - cred_location (str): The tmp path on disk to to store the encrypted blob - descriptor (str): Service destination (ie: slack channel, pd integration) - - Returns: - bool: True if download of creds from s3 was a success - """ - try: - if not os.path.exists(os.path.dirname(cred_location)): - os.makedirs(os.path.dirname(cred_location)) - - client = boto3.client('s3', region_name=self.region) - with open(cred_location, 'wb') as cred_output: - client.download_fileobj(self.secrets_bucket, - self.output_cred_name(descriptor), - cred_output) - - return True - except ClientError as err: - LOGGER.exception('credentials for \'%s\' could not be downloaded ' - 'from S3: %s', self.output_cred_name(descriptor), - err.response) - - def _kms_decrypt(self, data): - """Decrypt data with AWS KMS. - - Args: - data (str): An encrypted ciphertext data blob - - Returns: - str: Decrypted json string - """ - try: - client = boto3.client('kms', region_name=self.region) - response = client.decrypt(CiphertextBlob=data) - return response['Plaintext'] - except ClientError as err: - LOGGER.error('an error occurred during credentials decryption: %s', err.response) + return self._credentials_provider.load_credentials(descriptor) @classmethod def _log_status(cls, success, descriptor): @@ -438,25 +345,6 @@ def _get_default_properties(cls): """ pass - @classmethod - def output_cred_name(cls, descriptor): - """Formats the output name for this credential by combining the service - and the descriptor. - - Args: - descriptor (str): Service destination (ie: slack channel, pd integration) - - Returns: - str: Formatted credential name (ie: slack_ryandchannel) - """ - cred_name = str(cls.__service__) - - # should descriptor be enforced in all rules? - if descriptor: - cred_name = '{}/{}'.format(cred_name, descriptor) - - return cred_name - @classmethod def format_output_config(cls, service_config, values): """Add this descriptor to the list of descriptor this service diff --git a/stream_alert/shared/helpers/aws_api_client.py b/stream_alert/shared/helpers/aws_api_client.py new file mode 100644 index 000000000..dcdfb37ef --- /dev/null +++ b/stream_alert/shared/helpers/aws_api_client.py @@ -0,0 +1,184 @@ +""" +Copyright 2017-present, Airbnb Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" +import boto3 + +from botocore.exceptions import ClientError + +from stream_alert.shared.helpers.boto import default_config +from stream_alert.shared.logger import get_logger + +LOGGER = get_logger(__name__) + + +class AwsKms(object): + @staticmethod + def encrypt(plaintext_data, region, key_alias): + """Encrypts the given plaintext data using AWS KMS + + See: + https://docs.aws.amazon.com/kms/latest/APIReference/API_Encrypt.html + + Args: + plaintext_data (str): The raw, unencrypted data to be encrypted + region (str): AWS region + key_alias (str): KMS Key Alias + + Returns: + string: The encrypted ciphertext + + Raises: + ClientError + """ + try: + key_id = 'alias/{}'.format(key_alias) + client = boto3.client('kms', config=default_config(region=region)) + response = client.encrypt(KeyId=key_id, Plaintext=plaintext_data) + return response['CiphertextBlob'] + except ClientError: + LOGGER.error('An error occurred during KMS encryption') + raise + + @staticmethod + def decrypt(ciphertext, region): + """Decrypts the given ciphertext using AWS KMS + + See: + https://docs.aws.amazon.com/kms/latest/APIReference/API_Decrypt.html + + Args: + ciphertext (str): The raw, encrypted data to be decrypted + region (str): AWS region + + Returns: + string: The decrypted plaintext + + Raises: + ClientError + """ + try: + client = boto3.client('kms', config=default_config(region=region)) + response = client.decrypt(CiphertextBlob=ciphertext) + return response['Plaintext'] + except ClientError: + LOGGER.error('An error occurred during KMS decryption') + raise + + +class AwsS3(object): + @staticmethod + def head_bucket(bucket, region): + """Determines if given bucket exists with correct permissions. + + See: + https://docs.aws.amazon.com/AmazonS3/latest/API/RESTBucketHEAD.html + + Args: + bucket (str): AWS S3 bucket name + region (str): AWS Region + + Returns: + bool: True on success + + Raises: + ClientError; Raises when the bucket does not exist or is denying permission to access. + """ + try: + client = boto3.client('s3', config=default_config(region=region)) + client.head_bucket(Bucket=bucket) + except ClientError: + LOGGER.error('An error occurred during S3 HeadBucket') + raise + + @staticmethod + def create_bucket(bucket, region): + """Creates the given S3 bucket + + See: + https://docs.aws.amazon.com/cli/latest/reference/s3api/create-bucket.html + + Args: + bucket (str): The string name of the intended S3 bucket + region (str): AWS Region + + Returns: + bool: True on success + + Raises: + ClientError + """ + try: + client = boto3.client('s3', config=default_config(region=region)) + client.create_bucket(Bucket=bucket) + return True + except ClientError: + LOGGER.error('An error occurred during S3 CreateBucket') + raise + + @staticmethod + def put_object(object_data, bucket, key, region): + """Saves the given data into AWS S3 + + Args: + object_data (str): The raw object data to save + region (str): AWS region + bucket (str): AWS S3 bucket name + key (str): AWS S3 key name + + Returns: + bool: True on success + + Raises: + ClientError + """ + try: + client = boto3.client('s3', config=default_config(region=region)) + client.put_object(Body=object_data, Bucket=bucket, Key=key) + return True + except ClientError: + LOGGER.error('An error occurred during S3 PutObject') + raise + + @staticmethod + def download_fileobj(file_handle, bucket, key, region): + """Downloads the requested S3 object and saves it into the given file handle. + + This method also returns the downloaded payload. + + Args: + file_handle (File): A File-like object to save the downloaded contents + region (str): AWS region + bucket (str): AWS S3 bucket name + key (str): AWS S3 key name + + Returns: + str: The downloaded payload + + Raises: + ClientError + """ + try: + client = boto3.client('s3', config=default_config(region=region)) + client.download_fileobj( + bucket, + key, + file_handle + ) + + file_handle.seek(0) + return file_handle.read() + except ClientError: + LOGGER.error('An error occurred during S3 DownloadFileobj') + raise diff --git a/stream_alert_cli/outputs/handler.py b/stream_alert_cli/outputs/handler.py index 697728870..97bc5c118 100644 --- a/stream_alert_cli/outputs/handler.py +++ b/stream_alert_cli/outputs/handler.py @@ -14,9 +14,12 @@ limitations under the License. """ from stream_alert.shared.logger import get_logger -from stream_alert.alert_processor.outputs.output_base import StreamAlertOutput +from stream_alert.alert_processor.outputs.output_base import ( + StreamAlertOutput, + OutputCredentialsProvider +) from stream_alert_cli.helpers import user_input -from stream_alert_cli.outputs.helpers import encrypt_and_push_creds_to_s3, output_exists +from stream_alert_cli.outputs.helpers import output_exists LOGGER = get_logger(__name__) @@ -62,12 +65,9 @@ def output_handler(options, config): if output_exists(output_config, props, service): return output_handler(options, config) - secrets_bucket = '{}.streamalert.secrets'.format(prefix) - secrets_key = output.output_cred_name(props['descriptor'].value) - - # Encrypt the creds and push them to S3 - # then update the local output configuration with properties - if not encrypt_and_push_creds_to_s3(region, secrets_bucket, secrets_key, props, kms_key_alias): + provider = OutputCredentialsProvider(service, config=config, region=region, prefix=prefix) + result = provider.save_credentials(props['descriptor'].value, kms_key_alias, props) + if not result: LOGGER.error('An error occurred while saving \'%s\' ' 'output configuration for service \'%s\'', props['descriptor'].value, options.service) diff --git a/stream_alert_cli/outputs/helpers.py b/stream_alert_cli/outputs/helpers.py index 1dbf817fd..517649414 100644 --- a/stream_alert_cli/outputs/helpers.py +++ b/stream_alert_cli/outputs/helpers.py @@ -13,83 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. """ -import json - -import boto3 -from botocore.exceptions import ClientError from stream_alert.shared.logger import get_logger LOGGER = get_logger(__name__) -def encrypt_and_push_creds_to_s3(region, bucket, key, props, kms_key_alias): - """Construct a dictionary of the credentials we want to encrypt and send to s3 - - Args: - region (str): The aws region to use for boto3 client - bucket (str): The name of the s3 bucket to write the encrypted credentials to - key (str): ID for the s3 object to write the encrypted credentials to - props (OrderedDict): Contains various OutputProperty items - kms_key_alias (string): The KMS key alias to use for encryption of S3 objects - """ - creds = {name: prop.value - for (name, prop) in props.iteritems() if prop.cred_requirement} - - # Check if we have any creds to send to s3 - # Some services (ie: AWS) do not require this, so it's not an error - if not creds: - return True - - creds_json = json.dumps(creds) - enc_creds = kms_encrypt(region, creds_json, kms_key_alias) - return send_creds_to_s3(region, bucket, key, enc_creds) - - -def kms_encrypt(region, data, kms_key_alias): - """Encrypt data with AWS KMS. - - Args: - region (str): AWS region to use for boto3 client - data (str): json string to be encrypted - kms_key_alias (str): The KMS key alias to use for encryption of S3 objects - - Returns: - str: Encrypted ciphertext data blob - """ - try: - client = boto3.client('kms', region_name=region) - response = client.encrypt(KeyId='alias/{}'.format(kms_key_alias), - Plaintext=data) - return response['CiphertextBlob'] - except ClientError: - LOGGER.error('An error occurred during credential encryption') - raise - -def send_creds_to_s3(region, bucket, key, blob_data): - """Put the encrypted credential blob for this service and destination in s3 - - Args: - region (str): AWS region to use for boto3 client - bucket (str): The name of the s3 bucket to write the encrypted credentials to - key (str): ID for the s3 object to write the encrypted credentials to - blob_data (bytes): Cipher text blob from the kms encryption - """ - try: - client = boto3.client('s3', region_name=region) - client.put_object(Body=blob_data, Bucket=bucket, Key=key) - - return True - except ClientError as err: - LOGGER.error( - 'An error occurred while sending credentials to S3 for key \'%s\' ' - 'in bucket \'%s\': %s', - key, - bucket, - err.response['Error']['Message']) - return False - - def output_exists(config, props, service): """Determine if this service and destination combo has already been created diff --git a/stream_alert_cli/terraform/generate.py b/stream_alert_cli/terraform/generate.py index 2e45637cb..8fe5cb4b5 100644 --- a/stream_alert_cli/terraform/generate.py +++ b/stream_alert_cli/terraform/generate.py @@ -146,6 +146,7 @@ def generate_main(config, init=False): # Configure initial S3 buckets main_dict['resource']['aws_s3_bucket'] = { 'stream_alert_secrets': generate_s3_bucket( + # FIXME (derek.wang) DRY out by using OutputCredentialsProvider? bucket='{}.streamalert.secrets'.format(config['global']['account']['prefix']), logging=logging_bucket ), diff --git a/tests/unit/helpers/aws_mocks.py b/tests/unit/helpers/aws_mocks.py index 2e5c09c6a..944a16a1b 100644 --- a/tests/unit/helpers/aws_mocks.py +++ b/tests/unit/helpers/aws_mocks.py @@ -21,6 +21,7 @@ import boto3 from botocore.exceptions import ClientError +from stream_alert.shared.helpers.aws_api_client import AwsS3 class MockLambdaClient(object): """http://boto3.readthedocs.io/en/latest/reference/services/lambda.html""" @@ -229,11 +230,9 @@ def put_mock_s3_object(bucket, key, data, region='us-east-1'): data (str): the actual value to use for the object region (str): the aws region to use for this boto3 client """ - s3_client = boto3.client('s3', region_name=region) try: - # Check if the bucket exists before creating it - s3_client.head_bucket(Bucket=bucket) + AwsS3.head_bucket(bucket, region=region) except ClientError: - s3_client.create_bucket(Bucket=bucket) + AwsS3.create_bucket(bucket, region=region) - s3_client.put_object(Body=data, Bucket=bucket, Key=key, ServerSideEncryption='AES256') + AwsS3.put_object(data, bucket=bucket, key=key, region=region) diff --git a/tests/unit/stream_alert_alert_processor/helpers.py b/tests/unit/stream_alert_alert_processor/helpers.py index 724af08d2..074d63963 100644 --- a/tests/unit/stream_alert_alert_processor/helpers.py +++ b/tests/unit/stream_alert_alert_processor/helpers.py @@ -13,15 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. """ -import os import json import random -import shutil -import tempfile - -import boto3 +from stream_alert.alert_processor.outputs.credentials.provider import LocalFileDriver from stream_alert.shared.alert import Alert +from stream_alert.shared.helpers.aws_api_client import AwsKms from tests.unit.helpers.aws_mocks import put_mock_s3_object @@ -83,19 +80,12 @@ def get_alert(context=None): def remove_temp_secrets(): """Remove the local secrets directory that may be left from previous runs""" - secrets_dirtemp_dir = os.path.join(tempfile.gettempdir(), 'stream_alert_secrets') - - # Check if the folder exists, and remove it if it does - if os.path.isdir(secrets_dirtemp_dir): - shutil.rmtree(secrets_dirtemp_dir) + LocalFileDriver.clear() def encrypt_with_kms(data, region, alias): """Encrypt the given data with KMS.""" - kms_client = boto3.client('kms', region_name=region) - response = kms_client.encrypt(KeyId=alias, Plaintext=data) - - return response['CiphertextBlob'] + return AwsKms.encrypt(data, region=region, key_alias=alias) def put_mock_creds(output_name, creds, bucket, region, alias): diff --git a/tests/unit/stream_alert_alert_processor/test_outputs/credentials/__init__.py b/tests/unit/stream_alert_alert_processor/test_outputs/credentials/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/unit/stream_alert_alert_processor/test_outputs/credentials/test_provider.py b/tests/unit/stream_alert_alert_processor/test_outputs/credentials/test_provider.py new file mode 100644 index 000000000..a37650967 --- /dev/null +++ b/tests/unit/stream_alert_alert_processor/test_outputs/credentials/test_provider.py @@ -0,0 +1,742 @@ +""" +Copyright 2017-present, Airbnb Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" +# pylint: disable=abstract-class-instantiated,protected-access,attribute-defined-outside-init +import json +import os +from collections import OrderedDict + +from botocore.exceptions import ClientError +from mock import patch, MagicMock +from moto import mock_kms, mock_s3 +from nose.tools import ( + assert_true, + assert_equal, + assert_is_instance, + assert_is_not_none, + assert_false, + assert_is_none, +) + +from stream_alert.alert_processor.outputs.output_base import OutputProperty +from stream_alert.alert_processor.outputs.credentials.provider import ( + S3Driver, + LocalFileDriver, + Credentials, + OutputCredentialsProvider, + EphemeralUnencryptedDriver, SpooledTempfileDriver, get_formatted_output_credentials_name) +from tests.unit.stream_alert_alert_processor import ( + CONFIG, + KMS_ALIAS, + REGION, + MOCK_ENV) +from tests.unit.helpers.aws_mocks import put_mock_s3_object +from tests.unit.stream_alert_alert_processor.helpers import ( + encrypt_with_kms, + put_mock_creds, + remove_temp_secrets +) + + +# +# class Credentials Tests +# + + +class TestCredentialsEncrypted(object): + @mock_kms + def setup(self): + self._plaintext_payload = 'plaintext credentials' + self._encrypted_payload = encrypt_with_kms(self._plaintext_payload, REGION, KMS_ALIAS) + self._credentials = Credentials(self._encrypted_payload, is_encrypted=True, region=REGION) + + def test_is_encrypted(self): + """Credentials - Encrypted Credentials - Is Encrypted""" + assert_true(self._credentials.is_encrypted()) + + def test_is_data(self): + """Credentials - Encrypted Credentials - Data""" + assert_equal(self._credentials.data(), self._encrypted_payload) + + @mock_kms + def test_get_data_kms_decrypted(self): + """Credentials - Encrypted Credentials - KMS Decrypt""" + decrypted = self._credentials.get_data_kms_decrypted() + assert_equal(decrypted, self._plaintext_payload) + + def test_encrypt(self): + """Credentials - Encrypted Credentials - Encrypt + + Doubly-encrypting the credentials should do nothing. + """ + self._credentials.encrypt(REGION, KMS_ALIAS) + assert_equal(self._credentials.data(), self._encrypted_payload) + + @patch('boto3.client') + @patch('logging.Logger.exception') + def test_decrypt_kms_error(self, logging_exception, boto3): + """Credentials - Encrypted Credentials - KMS Decrypt - Errors if KMS Fails to Respond""" + + # We pretend that KMS errors out + boto3_client = MagicMock() + boto3.return_value = boto3_client + + response = MagicMock() + boto3_client.decrypt.side_effect = ClientError(response, 'kms_decrypt') + + assert_is_none(self._credentials.get_data_kms_decrypted()) + logging_exception.assert_called_with('an error occurred during credentials decryption') + + +class TestCredentialsUnencrypted(object): + def setup(self): + self._plaintext_payload = 'plaintext credentials' + self._credentials = Credentials(self._plaintext_payload, is_encrypted=False) + + def test_is_encrypted(self): + """Credentials - Plaintext Credentials - Is Encrypted""" + assert_false(self._credentials.is_encrypted()) + + def test_is_data(self): + """Credentials - Plaintext Credentials - Data""" + assert_equal(self._credentials.data(), self._plaintext_payload) + + @patch('logging.Logger.error') + def test_get_data_kms_decrypted(self, logging_error): + """Credentials - Plaintext Credentials - KMS Decrypt""" + assert_is_none(self._credentials.get_data_kms_decrypted()) + logging_error.assert_called_with('Cannot decrypt Credentials as they are already decrypted') + + @mock_kms + def test_encrypt(self): + """Credentials - Plaintext Credentials - Encrypt + + Doubly-encrypting the credentials should do nothing. + """ + self._credentials.encrypt(REGION, KMS_ALIAS) + + assert_true(self._credentials.is_encrypted()) + assert_equal(self._credentials.data(), 'InBsYWludGV4dCBjcmVkZW50aWFscyI=') + + +class TestCredentialsEmpty(object): + def setup(self): + self._plaintext_payload = '' + self._credentials = Credentials(self._plaintext_payload, is_encrypted=False) + + @mock_kms + def test_encrypt(self): + """Credentials - Empty Credentials - Encrypt - Does nothing when payload is empty""" + self._credentials.encrypt(REGION, KMS_ALIAS) + + assert_true(self._credentials.is_encrypted()) + assert_equal(self._credentials.data(), '') + + +# +# class OutputCredentialsProvider Tests +# + + +@patch.dict(os.environ, MOCK_ENV) +def test_constructor_loads_from_os_when_not_provided(): + """OutputCredentials - Constructor + + When not provided, prefix and aws account id are loaded from the OS Environment.""" + + provider = OutputCredentialsProvider('that_service_name', config=CONFIG, region=REGION) + assert_equal(provider._prefix, 'prefix') + assert_equal(provider.get_aws_account_id(), '123456789012') + + +@mock_s3 +class TestOutputCredentialsProvider(object): + def setup(self): + service_name = 'service' + defaults = { + 'property2': 'abcdef' + } + prefix = 'test_asdf' + aws_account_id = '1234567890' + + self._provider = OutputCredentialsProvider( + service_name, + config=CONFIG, + defaults=defaults, + region=REGION, + prefix=prefix, + aws_account_id=aws_account_id + ) + + # Pre-create the bucket so we dont get a "Bucket does not exist" error + s3_driver = S3Driver('test_asdf', 'service', REGION) + put_mock_s3_object(s3_driver.get_s3_secrets_bucket(), 'laskdjfaouhvawe', 'lafhawef', REGION) + + @mock_kms + def test_save_and_load_credentials(self): + """OutputCredentials - Save and Load Credentials + + Not only tests how save_credentials() interacts with load_credentials(), but also tests + that cred_requirement=False properties are not saved. Also tests that default values + are merged into the final credentials dict as appropriate.""" + + descriptor = 'test_save_and_load_credentials' + props = OrderedDict([ + ('property1', + OutputProperty(description='This is a property and not a cred so it will not save')), + ('property2', + OutputProperty(description='Neither will this')), + ('credential1', + OutputProperty(description='Hello world', + value='this is a super secret secret, shhhh!', + mask_input=True, + cred_requirement=True)), + ('credential2', + OutputProperty(description='This appears too!', + value='where am i?', + mask_input=True, + cred_requirement=True)), + ]) + + # Save credential + assert_true(self._provider.save_credentials(descriptor, KMS_ALIAS, props)) + + # Pull it out + creds_dict = self._provider.load_credentials(descriptor) + expectation = { + 'property2': 'abcdef', + 'credential1': 'this is a super secret secret, shhhh!', + 'credential2': 'where am i?', + } + assert_equal(creds_dict, expectation) + + @mock_kms + def test_load_credentials_multiple(self): + """OutputCredentials - Load Credentials Loads from Cache Driver + + This test ensures that we only hit S3 once during, and that subsequent calls are routed + to the Cache driver. Currently the cache driver is configured as Ephemeral.""" + + descriptor = 'test_load_credentials_pulls_from_cache' + props = OrderedDict([ + ('credential1', + OutputProperty(description='Hello world', + value='there is no cow level', + mask_input=True, + cred_requirement=True)), + ]) + + # Save credential + self._provider.save_credentials(descriptor, KMS_ALIAS, props) + + # Pull it out (Normal expected behavior) + creds_dict = self._provider.load_credentials(descriptor) + expectation = {'credential1': 'there is no cow level', 'property2': 'abcdef'} + assert_equal(creds_dict, expectation) + + # Now we yank the S3 driver out of the driver pool + # FIXME (derek.wang): Another way to do this is to install a spy on moto and make assertions + # on the number of times it is called. + assert_is_instance(self._provider._drivers[1], S3Driver) + self._provider._drivers[1] = None + self._provider._core_driver = None + + # Load again and see if it still is able to load without S3 + assert_equal(self._provider.load_credentials(descriptor), expectation) + + # Double-check; Examine the Driver guts and make sure that the EphemeralDriver has the + # value cached. + ep_driver = self._provider._drivers[0] + assert_is_instance(ep_driver, EphemeralUnencryptedDriver) + + assert_true(ep_driver.has_credentials(descriptor)) + creds = ep_driver.load_credentials(descriptor) + assert_equal(json.loads(creds.data())['credential1'], 'there is no cow level') + + @patch('logging.Logger.error') + def test_load_credentials_returns_none_on_driver_failure(self, logging_error): #pylint: disable=invalid-name + """OutputCredentials - Load Credentials Returns None on Driver Failure""" + descriptor = 'descriptive' + + # To pretend all drivers fail, we can just remove all of the drivers. + self._provider._drivers = [] + self._provider._core_driver = None + + creds_dict = self._provider.load_credentials(descriptor) + assert_is_none(creds_dict) + logging_error.assert_called_with('All drivers failed to retrieve credentials for [%s.%s]', + 'service', + descriptor) + +# +# Tests for S3Driver +# + + +class TestS3Driver(object): + def setup(self): + self._s3_driver = S3Driver('rawr', 'service_name', REGION) + + @patch('boto3.client') + @patch('logging.Logger.exception') + def test_load_credentials_s3_failure(self, logging_exception, boto3): + """S3Driver - Load String returns None on S3 Failure""" + descriptor = 'test_descriptor' + + # Pretend S3 fails to respond + boto3_client = MagicMock() + boto3.return_value = boto3_client + response = MagicMock() + boto3_client.download_fileobj.side_effect = ClientError(response, 's3_download_fileobj') + + assert_is_none(self._s3_driver.load_credentials(descriptor)) + logging_exception.assert_called_with( + "credentials for '%s' could not be downloaded from S3", + 'service_name/test_descriptor' + ) + + @mock_s3 + def test_load_credentials_plain_object(self): + """S3Driver - Load String from S3 + + In this test we save a simple string, unencrypted, into a mock S3 file. We use the + driver to pull out this payload verbatim.""" + test_data = 'encrypted credential test string' + descriptor = 'test_descriptor' + + # Stick some fake data into the credentials bucket file. + bucket_name = self._s3_driver.get_s3_secrets_bucket() + key = self._s3_driver.get_s3_key(descriptor) + put_mock_s3_object(bucket_name, key, test_data, REGION) + + credentials = self._s3_driver.load_credentials(descriptor) + + # (!) Notably, in this test the credential contents are not encrypted when setup. They + # are supposed to be encrypted PRIOR to putting it in. + assert_true(credentials.is_encrypted()) + assert_equal(credentials.data(), test_data) + + @mock_s3 + @mock_kms + def test_load_credentials_encrypted_credentials(self): + """S3Driver - Load Encrypted Credentials + + In this test we save a (more or less) real credentials payload using S3 mocking. We + use the driver to pull the payload out and ensure the returned Credentials object is + in a stable state, and that we can retrieve the decrypt credentials from this object.""" + descriptor = 'test_descriptor' + + bucket = self._s3_driver.get_s3_secrets_bucket() + key = self._s3_driver.get_s3_key(descriptor) + + creds = {'url': 'http://www.foo.bar/test', + 'token': 'token_to_encrypt'} + + # Save encrypted credentials + put_mock_creds(key, creds, bucket, REGION, KMS_ALIAS) + + credentials = self._s3_driver.load_credentials(descriptor) + + assert_is_not_none(credentials) + assert_true(credentials.is_encrypted()) + + loaded_creds = json.loads(credentials.get_data_kms_decrypted()) + + assert_equal(len(loaded_creds), 2) + assert_equal(loaded_creds['url'], u'http://www.foo.bar/test') + assert_equal(loaded_creds['token'], u'token_to_encrypt') + + def test_has_credentials(self): + """S3Driver - Has Credentials + + Not much of a test; we assume that S3 always has the credentials. + """ + assert_true(self._s3_driver.has_credentials('some_descriptor')) + + @mock_s3 + @mock_kms + def test_save_credentials_into_s3(self): + """S3Driver - Save Credentials + + We test a full cycle of using save_credentials() then subsequently pulling them out with + load_credentials().""" + creds = {'url': 'http://best.website.ever/test'} + input_credentials = Credentials(creds, is_encrypted=False, region=REGION) + descriptor = 'test_descriptor' + + # Annoyingly, moto needs us to create the bucket first + # We put a random unrelated object into the bucket and this will set up the bucket for us + put_mock_s3_object(self._s3_driver.get_s3_secrets_bucket(), 'aaa', 'bbb', REGION) + + result = self._s3_driver.save_credentials_into_s3(descriptor, input_credentials, KMS_ALIAS) + assert_true(result) + + credentials = self._s3_driver.load_credentials(descriptor) + + assert_is_not_none(credentials) + assert_true(credentials.is_encrypted()) + + loaded_creds = json.loads(credentials.get_data_kms_decrypted()) + + assert_equal(loaded_creds, creds) + + def test_save_credentials_into_s3_blank_credentials(self): + """S3Driver - Save Credentials does nothing when Credentials are Blank""" + input_credentials = Credentials('', is_encrypted=False, region=REGION) + descriptor = 'test_descriptor22' + + result = self._s3_driver.save_credentials_into_s3(descriptor, input_credentials, KMS_ALIAS) + assert_true(result) + + assert_is_none(self._s3_driver.load_credentials(descriptor)) + + def test_get_s3_secrets_bucket(self): + """S3Driver - Get S3 Secrets Bucket Name""" + assert_equal(self._s3_driver.get_s3_secrets_bucket(), 'rawr.streamalert.secrets') + + +class TestS3DriverWithFileDriver(object): + def setup(self): + service_name = 'test_service' + self._fs_driver = LocalFileDriver(REGION, service_name) + self._s3_driver = S3Driver('test_prefix', service_name, REGION, file_driver=self._fs_driver) + + @mock_s3 + @mock_kms + def test_load_credentials(self): + """S3Driver - With File Driver - Load Credentials - Pulls into LocalFileStore + + Here we use the S3Driver's caching ability to yank stuff into a local driver.""" + remove_temp_secrets() + + creds = {'my_secret': 'i ate two portions of biscuits and gravy'} + input_credentials = Credentials(creds, is_encrypted=False, region=REGION) + descriptor = 'test_descriptor' + + # Annoyingly, moto needs us to create the bucket first + # We put a random unrelated object into the bucket and this will set up the bucket for us + put_mock_s3_object(self._s3_driver.get_s3_secrets_bucket(), 'aaa', 'bbb', REGION) + + # First, check if the Local driver can find the credentials (we don't expect it to) + assert_false(self._fs_driver.has_credentials(descriptor)) + + # Save the credentials using S3 driver + result = self._s3_driver.save_credentials_into_s3(descriptor, input_credentials, KMS_ALIAS) + assert_true(result) + + # We still don't expect the Local driver to find the credentials + assert_false(self._fs_driver.has_credentials(descriptor)) + + # Use S3Driver to warm up the Local driver + self._s3_driver.load_credentials(descriptor) + + # Now we should be able to get the credentials from the local fs + assert_true(self._fs_driver.has_credentials(descriptor)) + credentials = self._fs_driver.load_credentials(descriptor) + + assert_is_not_none(credentials) + assert_true(credentials.is_encrypted()) + + loaded_creds = json.loads(credentials.get_data_kms_decrypted()) + + assert_equal(loaded_creds, creds) + + remove_temp_secrets() + + +# +# class LocalFileDriver Tests +# + + +def test_get_formatted_output_credentials_name(): + """LocalFileDriver - Get Formatted Output Credentials Name""" + name = get_formatted_output_credentials_name( + 'test_service_name', + 'test_descriptor' + ) + assert_equal(name, 'test_service_name/test_descriptor') + + +def test_get_load_credentials_temp_dir(): + """LocalFileDriver - Get Load Credentials Temp Dir""" + temp_dir = LocalFileDriver.get_local_credentials_temp_dir() + assert_equal(temp_dir.split('/')[-1], 'stream_alert_secrets') + + +def test_get_formatted_output_credentials_name_no_descriptor(): #pylint: disable=invalid-name + """LocalFileDriver - Get Formatted Output Credentials Name - No Descriptor""" + name = get_formatted_output_credentials_name( + 'test_service_name', + '' + ) + assert_equal(name, 'test_service_name') + + +class TestLocalFileDriver(object): + + def setup(self): + LocalFileDriver.clear() + self._fs_driver = LocalFileDriver(REGION, 'service') + + @staticmethod + def teardown(): + LocalFileDriver.clear() + + def test_save_and_has_credentials(self): + """LocalFileDriver - Save and Has Credentials""" + assert_false(self._fs_driver.has_credentials('descriptor')) + + credentials = Credentials('aaaa', True) # pretend it's encrypted + self._fs_driver.save_credentials('descriptor', credentials) + + assert_true(self._fs_driver.has_credentials('descriptor')) + + @mock_kms + def test_save_and_load_credentials(self): + """LocalFileDriver - Save and Load Credentials""" + raw_credentials = 'aaaa' + descriptor = 'descriptor' + + encrypted_raw_credentials = encrypt_with_kms(raw_credentials, REGION, KMS_ALIAS) + + credentials = Credentials(encrypted_raw_credentials, True, REGION) + assert_true(self._fs_driver.save_credentials(descriptor, credentials)) + + loaded_credentials = self._fs_driver.load_credentials(descriptor) + + assert_is_not_none(loaded_credentials) + assert_true(loaded_credentials.is_encrypted()) + assert_equal(loaded_credentials.get_data_kms_decrypted(), raw_credentials) + + @mock_kms + def test_save_and_load_credentials_persists_statically(self): + """LocalFileDriver - Save and Load Credentials""" + raw_credentials = 'aaaa' + descriptor = 'descriptor' + + encrypted_raw_credentials = encrypt_with_kms(raw_credentials, REGION, KMS_ALIAS) + + credentials = Credentials(encrypted_raw_credentials, True, REGION) + assert_true(self._fs_driver.save_credentials(descriptor, credentials)) + + driver2 = LocalFileDriver(REGION, 'service') # Create a separate, identical driver + loaded_credentials = driver2.load_credentials(descriptor) + + assert_is_not_none(loaded_credentials) + assert_true(loaded_credentials.is_encrypted()) + assert_equal(loaded_credentials.get_data_kms_decrypted(), raw_credentials) + + def test_save_errors_on_unencrypted(self): + """LocalFileDriver - Save Errors on Unencrypted Credentials""" + raw_credentials_dict = { + 'python': 'is very difficult', + 'someone': 'save meeeee', + } + descriptor = 'descriptor5' + raw_credentials = json.dumps(raw_credentials_dict) + + credentials = Credentials(raw_credentials, False, REGION) + + assert_false(self._fs_driver.save_credentials(descriptor, credentials)) + assert_false(self._fs_driver.has_credentials(descriptor)) + + def test_clear(self): + """LocalFileDriver - Clear Credentials""" + descriptor = 'descriptor' + + credentials = Credentials('aaaa', True, REGION) # pretend it's encrypted + self._fs_driver.save_credentials(descriptor, credentials) + + LocalFileDriver.clear() + + assert_false(self._fs_driver.has_credentials(descriptor)) + + +# +# class TestSpooledTempfileDriver tests +# + + +class TestSpooledTempfileDriver(object): + + def setup(self): + SpooledTempfileDriver.clear() + self._sp_driver = SpooledTempfileDriver('service', REGION) + + @staticmethod + def teardown(): + SpooledTempfileDriver.clear() + + def test_save_and_has_credentials(self): + """SpooledTempfileDriver - Save and Has Credentials""" + assert_false(self._sp_driver.has_credentials('descriptor')) + + credentials = Credentials('aaaa', True) # let's pretend they're encrypted + assert_true(self._sp_driver.save_credentials('descriptor', credentials)) + + assert_true(self._sp_driver.has_credentials('descriptor')) + + @mock_kms + def test_save_and_load_credentials(self): + """SpooledTempfileDriver - Save and Load Credentials""" + raw_credentials = 'aaaa' + descriptor = 'descriptor' + encrypted_raw_credentials = encrypt_with_kms(raw_credentials, REGION, KMS_ALIAS) + + credentials = Credentials(encrypted_raw_credentials, True, REGION) + assert_true(self._sp_driver.save_credentials(descriptor, credentials)) + + loaded_credentials = self._sp_driver.load_credentials(descriptor) + + assert_is_not_none(loaded_credentials) + assert_true(loaded_credentials.is_encrypted()) + assert_equal(loaded_credentials.get_data_kms_decrypted(), raw_credentials) + + @mock_kms + def test_save_and_load_credentials_persists_statically(self): + """SpooledTempfileDriver - Save and Load Credentials""" + raw_credentials_dict = { + 'python': 'is very difficult', + 'someone': 'save meeeee', + } + descriptor = 'descriptor' + + raw_credentials = json.dumps(raw_credentials_dict) + encrypted_raw_credentials = encrypt_with_kms(raw_credentials, REGION, KMS_ALIAS) + + credentials = Credentials(encrypted_raw_credentials, True) + assert_true(self._sp_driver.save_credentials(descriptor, credentials)) + + driver2 = SpooledTempfileDriver('service', REGION) # Create a separate, identical driver + loaded_credentials = driver2.load_credentials(descriptor) + + assert_is_not_none(loaded_credentials) + assert_true(loaded_credentials.is_encrypted()) + assert_equal(loaded_credentials.get_data_kms_decrypted(), raw_credentials) + + def test_save_errors_on_unencrypted(self): + """SpooledTempfileDriver - Save Errors on Unencrypted Credentials""" + raw_credentials = 'aaaa' + descriptor = 'descriptor5' + + credentials = Credentials(raw_credentials, False) + + assert_false(self._sp_driver.save_credentials(descriptor, credentials)) + assert_false(self._sp_driver.has_credentials(descriptor)) + + @patch('logging.Logger.error') + def test_load_credentials_nonexistent(self, logging_error): + """SpooledTempfileDriver - Load Credentials returns None on missing""" + assert_false(self._sp_driver.has_credentials('qwertyuiop')) + assert_is_none(self._sp_driver.load_credentials('qwertyuiop')) + logging_error.assert_called_with( + 'SpooledTempfileDriver failed to load_credentials: Spool "%s" does not exist?', + 'service/qwertyuiop' + ) + + def test_clear(self): + """SpooledTempfileDriver - Clear Credentials""" + descriptor = 'descriptor' + credentials = Credentials('aaaa', True) # pretend it's encrypted + + assert_true(self._sp_driver.save_credentials(descriptor, credentials)) + + SpooledTempfileDriver.clear() + + assert_false(self._sp_driver.has_credentials(descriptor)) + + +# +# class EphemeralUnencryptedDriver tests +# + +class TestEphemeralUnencryptedDriver(object): + + def setup(self): + EphemeralUnencryptedDriver.clear() + self._ep_driver = EphemeralUnencryptedDriver('service') + + @staticmethod + def teardown(): + EphemeralUnencryptedDriver.clear() + + def test_save_and_has_credentials(self): + """EphemeralUnencryptedDriver - Save and Has Credentials""" + assert_false(self._ep_driver.has_credentials('descriptor')) + + credentials = Credentials('aaaa', False) + assert_true(self._ep_driver.save_credentials('descriptor', credentials)) + + assert_true(self._ep_driver.has_credentials('descriptor')) + + def test_save_and_load_credentials(self): + """EphemeralUnencryptedDriver - Save and Load Credentials""" + descriptor = 'descriptor' + credentials = Credentials('aaaa', False) + assert_true(self._ep_driver.save_credentials(descriptor, credentials)) + + loaded_credentials = self._ep_driver.load_credentials(descriptor) + + assert_is_not_none(loaded_credentials) + assert_false(loaded_credentials.is_encrypted()) + assert_equal(loaded_credentials.data(), 'aaaa') + + def test_save_and_load_credentials_persists_statically(self): + """EphemeralUnencryptedDriver - Save and Load Credentials""" + descriptor = 'descriptor' + credentials = Credentials('aaaa', False) + + assert_true(self._ep_driver.save_credentials(descriptor, credentials)) + + driver2 = EphemeralUnencryptedDriver('service') # Create a separate, identical driver + loaded_credentials = driver2.load_credentials(descriptor) + + assert_is_not_none(loaded_credentials) + assert_false(loaded_credentials.is_encrypted()) + assert_equal(loaded_credentials.data(), 'aaaa') + + @mock_kms + def test_save_automatically_decrypts(self): + """EphemeralUnencryptedDriver - Save Automatically Decrypts""" + raw_credentials_dict = { + 'python': 'is very difficult', + 'someone': 'save meeeee', + } + descriptor = 'descriptor5' + + raw_credentials = json.dumps(raw_credentials_dict) + encrypted_raw_credentials = encrypt_with_kms(raw_credentials, REGION, KMS_ALIAS) + + credentials = Credentials(encrypted_raw_credentials, True, REGION) + + assert_true(self._ep_driver.save_credentials(descriptor, credentials)) + + loaded_credentials = self._ep_driver.load_credentials(descriptor) + + assert_is_not_none(loaded_credentials) + assert_false(loaded_credentials.is_encrypted()) + assert_equal(json.loads(loaded_credentials.data()), raw_credentials_dict) + + def test_clear(self): + """EphemeralUnencryptedDriver - Clear Credentials""" + descriptor = 'descriptor' + + credentials = Credentials('aaaa', False) + self._ep_driver.save_credentials(descriptor, credentials) + + EphemeralUnencryptedDriver.clear() + + assert_false(self._ep_driver.has_credentials(descriptor)) diff --git a/tests/unit/stream_alert_alert_processor/test_outputs/test_carbonblack.py b/tests/unit/stream_alert_alert_processor/test_outputs/test_carbonblack.py index dbc107e46..bb35df6f9 100644 --- a/tests/unit/stream_alert_alert_processor/test_outputs/test_carbonblack.py +++ b/tests/unit/stream_alert_alert_processor/test_outputs/test_carbonblack.py @@ -16,24 +16,14 @@ # pylint: disable=no-self-use,unused-argument,attribute-defined-outside-init,protected-access from collections import OrderedDict -from mock import call, patch -from moto import mock_s3, mock_kms +from mock import call, patch, Mock, MagicMock from nose.tools import assert_false, assert_is_instance, assert_true from stream_alert.alert_processor.outputs import carbonblack from stream_alert.alert_processor.outputs.carbonblack import CarbonBlackOutput -from tests.unit.stream_alert_alert_processor import ( - CONFIG, - KMS_ALIAS, - MOCK_ENV, - REGION -) +from tests.unit.stream_alert_alert_processor import CONFIG from tests.unit.helpers.mocks import MockCBAPI -from tests.unit.stream_alert_alert_processor.helpers import ( - get_alert, - put_mock_creds, - remove_temp_secrets -) +from tests.unit.stream_alert_alert_processor.helpers import get_alert @patch('stream_alert.alert_processor.outputs.output_base.OutputDispatcher.MAX_RETRY_ATTEMPTS', 1) @@ -46,22 +36,16 @@ class TestCarbonBlackOutput(object): 'ssl_verify': 'Y', 'token': '1234567890127a3d7f37f4153270bff41b105899'} - @patch.dict('os.environ', MOCK_ENV) - def setup(self): + @patch('stream_alert.alert_processor.outputs.output_base.OutputCredentialsProvider') + def setup(self, provider_constructor): """Setup before each method""" - self._mock_s3 = mock_s3() - self._mock_s3.start() - self._mock_kms = mock_kms() - self._mock_kms.start() + provider = MagicMock() + provider_constructor.return_value = provider + provider.load_credentials = Mock( + side_effect=lambda x: self.CREDS if x == self.DESCRIPTOR else None + ) + self._provider = provider self._dispatcher = CarbonBlackOutput(CONFIG) - remove_temp_secrets() - output_name = self._dispatcher.output_cred_name(self.DESCRIPTOR) - put_mock_creds(output_name, self.CREDS, self._dispatcher.secrets_bucket, REGION, KMS_ALIAS) - - def teardown(self): - """Teardown after each method""" - self._mock_s3.stop() - self._mock_kms.stop() def test_get_user_defined_properties(self): """CarbonBlackOutput - User Defined Properties""" diff --git a/tests/unit/stream_alert_alert_processor/test_outputs/test_github.py b/tests/unit/stream_alert_alert_processor/test_outputs/test_github.py index 1692efea0..e6e438a46 100644 --- a/tests/unit/stream_alert_alert_processor/test_outputs/test_github.py +++ b/tests/unit/stream_alert_alert_processor/test_outputs/test_github.py @@ -16,21 +16,11 @@ # pylint: disable=protected-access,attribute-defined-outside-init,no-self-use import base64 -from mock import patch -from moto import mock_s3, mock_kms +from mock import patch, Mock, MagicMock from nose.tools import assert_false, assert_true, assert_equal, assert_is_not_none from stream_alert.alert_processor.outputs.github import GithubOutput -from tests.unit.stream_alert_alert_processor import ( - KMS_ALIAS, - MOCK_ENV, - REGION -) -from tests.unit.stream_alert_alert_processor.helpers import ( - get_alert, - put_mock_creds, - remove_temp_secrets -) +from tests.unit.stream_alert_alert_processor.helpers import get_alert @patch('stream_alert.alert_processor.outputs.output_base.OutputDispatcher.MAX_RETRY_ATTEMPTS', 1) @@ -39,26 +29,23 @@ class TestGithubOutput(object): DESCRIPTOR = 'unit_test_repo' SERVICE = 'github' OUTPUT = ':'.join([SERVICE, DESCRIPTOR]) - CREDS = {'username': 'unit_test_user', 'access_token': - 'unit_test_access_token', 'repository': 'unit_test_org/unit_test_repo', - 'labels': 'label1,label2'} - - @patch.dict('os.environ', MOCK_ENV) - def setup(self): + CREDS = {'username': 'unit_test_user', + 'access_token': 'unit_test_access_token', + 'repository': 'unit_test_org/unit_test_repo', + 'labels': 'label1,label2', + 'api': 'https://api.github.com', + } + + @patch('stream_alert.alert_processor.outputs.output_base.OutputCredentialsProvider') + def setup(self, provider_constructor): """Setup before each method""" - self._mock_s3 = mock_s3() - self._mock_s3.start() - self._mock_kms = mock_kms() - self._mock_kms.start() + provider = MagicMock() + provider_constructor.return_value = provider + provider.load_credentials = Mock( + side_effect=lambda x: self.CREDS if x == self.DESCRIPTOR else None + ) + self._provider = provider self._dispatcher = GithubOutput(None) - remove_temp_secrets() - output_name = self._dispatcher.output_cred_name(self.DESCRIPTOR) - put_mock_creds(output_name, self.CREDS, self._dispatcher.secrets_bucket, REGION, KMS_ALIAS) - - def teardown(self): - """Teardown after each method""" - self._mock_s3.stop() - self._mock_kms.stop() @patch('logging.Logger.info') @patch('requests.post') diff --git a/tests/unit/stream_alert_alert_processor/test_outputs/test_jira.py b/tests/unit/stream_alert_alert_processor/test_outputs/test_jira.py index b25f5c427..3d4123de7 100644 --- a/tests/unit/stream_alert_alert_processor/test_outputs/test_jira.py +++ b/tests/unit/stream_alert_alert_processor/test_outputs/test_jira.py @@ -14,21 +14,11 @@ limitations under the License. """ # pylint: disable=protected-access,attribute-defined-outside-init -from mock import patch, PropertyMock -from moto import mock_s3, mock_kms +from mock import patch, PropertyMock, Mock, MagicMock from nose.tools import assert_equal, assert_false, assert_true from stream_alert.alert_processor.outputs.jira import JiraOutput -from tests.unit.stream_alert_alert_processor import ( - KMS_ALIAS, - MOCK_ENV, - REGION -) -from tests.unit.stream_alert_alert_processor.helpers import ( - get_alert, - put_mock_creds, - remove_temp_secrets -) +from tests.unit.stream_alert_alert_processor.helpers import get_alert @patch('stream_alert.alert_processor.outputs.output_base.OutputDispatcher.MAX_RETRY_ATTEMPTS', 1) @@ -44,23 +34,17 @@ class TestJiraOutput(object): 'issue_type': 'Task', 'aggregate': 'yes'} - @patch.dict('os.environ', MOCK_ENV) - def setup(self): + @patch('stream_alert.alert_processor.outputs.output_base.OutputCredentialsProvider') + def setup(self, provider_constructor): """Setup before each method""" - self._mock_s3 = mock_s3() - self._mock_s3.start() - self._mock_kms = mock_kms() - self._mock_kms.start() + provider = MagicMock() + provider_constructor.return_value = provider + provider.load_credentials = Mock( + side_effect=lambda x: self.CREDS if x == self.DESCRIPTOR else None + ) + self._provider = provider self._dispatcher = JiraOutput(None) self._dispatcher._base_url = self.CREDS['url'] - remove_temp_secrets() - output_name = self._dispatcher.output_cred_name(self.DESCRIPTOR) - put_mock_creds(output_name, self.CREDS, self._dispatcher.secrets_bucket, REGION, KMS_ALIAS) - - def teardown(self): - """Teardown after each method""" - self._mock_s3.stop() - self._mock_kms.stop() @patch('logging.Logger.info') @patch('requests.get') diff --git a/tests/unit/stream_alert_alert_processor/test_outputs/test_komand.py b/tests/unit/stream_alert_alert_processor/test_outputs/test_komand.py index 3c8342db0..01d762a7b 100644 --- a/tests/unit/stream_alert_alert_processor/test_outputs/test_komand.py +++ b/tests/unit/stream_alert_alert_processor/test_outputs/test_komand.py @@ -14,21 +14,11 @@ limitations under the License. """ # pylint: disable=protected-access,attribute-defined-outside-init -from mock import patch -from moto import mock_s3, mock_kms +from mock import patch, Mock, MagicMock from nose.tools import assert_false, assert_true from stream_alert.alert_processor.outputs.komand import KomandOutput -from tests.unit.stream_alert_alert_processor import ( - KMS_ALIAS, - MOCK_ENV, - REGION -) -from tests.unit.stream_alert_alert_processor.helpers import ( - get_alert, - put_mock_creds, - remove_temp_secrets -) +from tests.unit.stream_alert_alert_processor.helpers import get_alert @patch('stream_alert.alert_processor.outputs.output_base.OutputDispatcher.MAX_RETRY_ATTEMPTS', 1) @@ -40,22 +30,16 @@ class TestKomandutput(object): CREDS = {'url': 'http://komand.foo.bar', 'komand_auth_token': 'mocked_auth_token'} - @patch.dict('os.environ', MOCK_ENV) - def setup(self): + @patch('stream_alert.alert_processor.outputs.output_base.OutputCredentialsProvider') + def setup(self, provider_constructor): """Setup before each method""" - self._mock_s3 = mock_s3() - self._mock_s3.start() - self._mock_kms = mock_kms() - self._mock_kms.start() + provider = MagicMock() + provider_constructor.return_value = provider + provider.load_credentials = Mock( + side_effect=lambda x: self.CREDS if x == self.DESCRIPTOR else None + ) + self._provider = provider self._dispatcher = KomandOutput(None) - remove_temp_secrets() - output_name = self._dispatcher.output_cred_name(self.DESCRIPTOR) - put_mock_creds(output_name, self.CREDS, self._dispatcher.secrets_bucket, REGION, KMS_ALIAS) - - def teardown(self): - """Teardown after each method""" - self._mock_s3.stop() - self._mock_kms.stop() @patch('logging.Logger.info') @patch('requests.post') diff --git a/tests/unit/stream_alert_alert_processor/test_outputs/test_output_base.py b/tests/unit/stream_alert_alert_processor/test_outputs/test_output_base.py index 30843eef0..241777d78 100644 --- a/tests/unit/stream_alert_alert_processor/test_outputs/test_output_base.py +++ b/tests/unit/stream_alert_alert_processor/test_outputs/test_output_base.py @@ -14,9 +14,7 @@ limitations under the License. """ # pylint: disable=abstract-class-instantiated,protected-access,attribute-defined-outside-init -import os - -from mock import Mock, patch +from mock import Mock, patch, MagicMock from moto import mock_kms, mock_s3 from nose.tools import ( assert_equal, @@ -27,6 +25,8 @@ ) from requests.exceptions import Timeout as ReqTimeout +from stream_alert.alert_processor.outputs.credentials.provider import \ + get_formatted_output_credentials_name from stream_alert.alert_processor.outputs.output_base import ( OutputDispatcher, OutputProperty, @@ -40,9 +40,7 @@ MOCK_ENV, REGION ) -from tests.unit.helpers.aws_mocks import put_mock_s3_object from tests.unit.stream_alert_alert_processor.helpers import ( - encrypt_with_kms, put_mock_creds, remove_temp_secrets ) @@ -116,6 +114,7 @@ def test_output_loading(): class TestOutputDispatcher(object): """Test class for OutputDispatcher""" + @patch.object(OutputDispatcher, '__service__', 'test_service') @patch.object(OutputDispatcher, '__abstractmethods__', frozenset()) @patch.dict('os.environ', MOCK_ENV) def setup(self): @@ -123,43 +122,19 @@ def setup(self): self._dispatcher = OutputDispatcher(CONFIG) self._descriptor = 'desc_test' - def test_local_temp_dir(self): - """OutputDispatcher - Local Temp Dir""" - temp_dir = self._dispatcher._local_temp_dir() - assert_equal(temp_dir.split('/')[-1], 'stream_alert_secrets') - - def test_output_cred_name(self): - """OutputDispatcher - Output Cred Name""" - output_name = self._dispatcher.output_cred_name('creds') - assert_equal(output_name, 'test_service/creds') - - @mock_s3 - def test_get_creds_from_s3(self): - """OutputDispatcher - Get Creds From S3""" - test_data = 'credential test string' - - bucket_name = self._dispatcher.secrets_bucket - key = self._dispatcher.output_cred_name(self._descriptor) - - local_cred_location = os.path.join(self._dispatcher._local_temp_dir(), key) - - put_mock_s3_object(bucket_name, key, test_data, REGION) - - self._dispatcher._get_creds_from_s3(local_cred_location, self._descriptor) - - with open(local_cred_location) as creds: - line = creds.readline() - - assert_equal(line, test_data) + @patch.object(OutputDispatcher, '__service__', 'test_service') + @patch.object(OutputDispatcher, '__abstractmethods__', frozenset()) + @patch('stream_alert.alert_processor.outputs.output_base.OutputCredentialsProvider') + def test_credentials_provider(self, provider_constructor): + """OutputDispatcher - Constructor""" + provider = MagicMock() + provider_constructor.return_value = provider - @mock_kms - def test_kms_decrypt(self): - """OutputDispatcher - KMS Decrypt""" - test_data = 'data to encrypt' - encrypted = encrypt_with_kms(test_data, REGION, KMS_ALIAS) - decrypted = self._dispatcher._kms_decrypt(encrypted) + _ = OutputDispatcher(CONFIG) - assert_equal(decrypted, test_data) + provider_constructor.assert_called_with('test_service', + config=CONFIG, defaults=None, region=REGION) + assert_equal(self._dispatcher._credentials_provider._service_name, 'test_service') @patch('logging.Logger.info') def test_log_status_success(self, log_mock): @@ -193,12 +168,17 @@ def test_check_http_response(self, mock_response): def test_load_creds(self): """OutputDispatcher - Load Credentials""" remove_temp_secrets() - output_name = self._dispatcher.output_cred_name(self._descriptor) + key = get_formatted_output_credentials_name( + 'test_service', + self._descriptor + ) creds = {'url': 'http://www.foo.bar/test', 'token': 'token_to_encrypt'} - put_mock_creds(output_name, creds, self._dispatcher.secrets_bucket, REGION, KMS_ALIAS) + put_mock_creds(key, creds, + self._dispatcher._credentials_provider._core_driver._bucket, + REGION, KMS_ALIAS) loaded_creds = self._dispatcher._load_creds(self._descriptor) diff --git a/tests/unit/stream_alert_alert_processor/test_outputs/test_pagerduty.py b/tests/unit/stream_alert_alert_processor/test_outputs/test_pagerduty.py index 75f662370..1d88d1d2e 100644 --- a/tests/unit/stream_alert_alert_processor/test_outputs/test_pagerduty.py +++ b/tests/unit/stream_alert_alert_processor/test_outputs/test_pagerduty.py @@ -14,8 +14,7 @@ limitations under the License. """ # pylint: disable=protected-access,attribute-defined-outside-init -from mock import patch, PropertyMock -from moto import mock_s3, mock_kms +from mock import patch, PropertyMock, Mock, MagicMock from nose.tools import assert_equal, assert_false, assert_true # import cProfile, pstats, StringIO @@ -24,17 +23,7 @@ PagerDutyOutputV2, PagerDutyIncidentOutput ) -from tests.unit.stream_alert_alert_processor import ( - KMS_ALIAS, - MOCK_ENV, - REGION -) - -from tests.unit.stream_alert_alert_processor.helpers import ( - get_alert, - put_mock_creds, - remove_temp_secrets -) +from tests.unit.stream_alert_alert_processor.helpers import get_alert @patch('stream_alert.alert_processor.outputs.output_base.OutputDispatcher.MAX_RETRY_ATTEMPTS', 1) @@ -46,22 +35,16 @@ class TestPagerDutyOutput(object): CREDS = {'url': 'http://pagerduty.foo.bar/create_event.json', 'service_key': 'mocked_service_key'} - @patch.dict('os.environ', MOCK_ENV) - def setup(self): + @patch('stream_alert.alert_processor.outputs.output_base.OutputCredentialsProvider') + def setup(self, provider_constructor): """Setup before each method""" - self._mock_s3 = mock_s3() - self._mock_s3.start() - self._mock_kms = mock_kms() - self._mock_kms.start() + provider = MagicMock() + provider_constructor.return_value = provider + provider.load_credentials = Mock( + side_effect=lambda x: self.CREDS if x == self.DESCRIPTOR else None + ) + self._provider = provider self._dispatcher = PagerDutyOutput(None) - remove_temp_secrets() - output_name = self._dispatcher.output_cred_name(self.DESCRIPTOR) - put_mock_creds(output_name, self.CREDS, self._dispatcher.secrets_bucket, REGION, KMS_ALIAS) - - def teardown(self): - """Teardown after each method""" - self._mock_s3.stop() - self._mock_kms.stop() def test_get_default_properties(self): """PagerDutyOutput - Get Default Properties""" @@ -109,22 +92,16 @@ class TestPagerDutyOutputV2(object): CREDS = {'url': 'http://pagerduty.foo.bar/create_event.json', 'routing_key': 'mocked_routing_key'} - @patch.dict('os.environ', MOCK_ENV) - def setup(self): + @patch('stream_alert.alert_processor.outputs.output_base.OutputCredentialsProvider') + def setup(self, provider_constructor): """Setup before each method""" - self._mock_s3 = mock_s3() - self._mock_s3.start() - self._mock_kms = mock_kms() - self._mock_kms.start() + provider = MagicMock() + provider_constructor.return_value = provider + provider.load_credentials = Mock( + side_effect=lambda x: self.CREDS if x == self.DESCRIPTOR else None + ) + self._provider = provider self._dispatcher = PagerDutyOutputV2(None) - remove_temp_secrets() - output_name = self._dispatcher.output_cred_name(self.DESCRIPTOR) - put_mock_creds(output_name, self.CREDS, self._dispatcher.secrets_bucket, REGION, KMS_ALIAS) - - def teardown(self): - """Teardown after each method""" - self._mock_s3.stop() - self._mock_kms.stop() def test_get_default_properties(self): """PagerDutyOutputV2 - Get Default Properties""" @@ -182,23 +159,17 @@ class TestPagerDutyIncidentOutput(object): 'email_from': 'email@domain.com', 'integration_key': 'mocked_key'} - @patch.dict('os.environ', MOCK_ENV) - def setup(self): + @patch('stream_alert.alert_processor.outputs.output_base.OutputCredentialsProvider') + def setup(self, provider_constructor): """Setup before each method""" - self._mock_s3 = mock_s3() - self._mock_s3.start() - self._mock_kms = mock_kms() - self._mock_kms.start() + provider = MagicMock() + provider_constructor.return_value = provider + provider.load_credentials = Mock( + side_effect=lambda x: self.CREDS if x == self.DESCRIPTOR else None + ) + self._provider = provider self._dispatcher = PagerDutyIncidentOutput(None) self._dispatcher._base_url = self.CREDS['api'] - remove_temp_secrets() - output_name = self._dispatcher.output_cred_name(self.DESCRIPTOR) - put_mock_creds(output_name, self.CREDS, self._dispatcher.secrets_bucket, REGION, KMS_ALIAS) - - def teardown(self): - """Teardown after each method""" - self._mock_s3.stop() - self._mock_kms.stop() def test_get_default_properties(self): """PagerDutyIncidentOutput - Get Default Properties""" diff --git a/tests/unit/stream_alert_alert_processor/test_outputs/test_phantom.py b/tests/unit/stream_alert_alert_processor/test_outputs/test_phantom.py index 038cbda92..9b369d366 100644 --- a/tests/unit/stream_alert_alert_processor/test_outputs/test_phantom.py +++ b/tests/unit/stream_alert_alert_processor/test_outputs/test_phantom.py @@ -14,23 +14,10 @@ limitations under the License. """ # pylint: disable=protected-access,attribute-defined-outside-init -from mock import call, patch, PropertyMock -from moto import mock_s3, mock_kms +from mock import call, patch, PropertyMock, Mock, MagicMock from nose.tools import assert_false, assert_true - from stream_alert.alert_processor.outputs.phantom import PhantomOutput -from tests.unit.stream_alert_alert_processor import ( - KMS_ALIAS, - MOCK_ENV, - REGION -) - -from tests.unit.stream_alert_alert_processor.helpers import ( - get_alert, - put_mock_creds, - remove_temp_secrets -) - +from tests.unit.stream_alert_alert_processor.helpers import get_alert @patch('stream_alert.alert_processor.outputs.output_base.OutputDispatcher.MAX_RETRY_ATTEMPTS', 1) class TestPhantomOutput(object): @@ -41,22 +28,17 @@ class TestPhantomOutput(object): CREDS = {'url': 'http://phantom.foo.bar', 'ph_auth_token': 'mocked_auth_token'} - @patch.dict('os.environ', MOCK_ENV) - def setup(self): + @patch('stream_alert.alert_processor.outputs.output_base.OutputCredentialsProvider') + def setup(self, provider_constructor): """Setup before each method""" - self._mock_s3 = mock_s3() - self._mock_s3.start() - self._mock_kms = mock_kms() - self._mock_kms.start() + provider = MagicMock() + provider_constructor.return_value = provider + provider.load_credentials = Mock( + side_effect=lambda x: self.CREDS if x == self.DESCRIPTOR else None + ) + + self._provider = provider self._dispatcher = PhantomOutput(None) - remove_temp_secrets() - output_name = self._dispatcher.output_cred_name(self.DESCRIPTOR) - put_mock_creds(output_name, self.CREDS, self._dispatcher.secrets_bucket, REGION, KMS_ALIAS) - - def teardown(self): - """Teardown after each method""" - self._mock_s3.stop() - self._mock_kms.stop() @patch('logging.Logger.info') @patch('requests.get') diff --git a/tests/unit/stream_alert_alert_processor/test_outputs/test_slack.py b/tests/unit/stream_alert_alert_processor/test_outputs/test_slack.py index 9da54a493..a6ce6d3f1 100644 --- a/tests/unit/stream_alert_alert_processor/test_outputs/test_slack.py +++ b/tests/unit/stream_alert_alert_processor/test_outputs/test_slack.py @@ -15,22 +15,13 @@ """ # pylint: disable=protected-access,attribute-defined-outside-init,no-self-use from collections import Counter, OrderedDict - -from mock import patch -from moto import mock_s3, mock_kms +from mock import patch, Mock, MagicMock from nose.tools import assert_equal, assert_false, assert_true, assert_set_equal from stream_alert.alert_processor.outputs.slack import SlackOutput -from tests.unit.stream_alert_alert_processor import ( - KMS_ALIAS, - MOCK_ENV, - REGION -) from tests.unit.stream_alert_alert_processor.helpers import ( get_random_alert, get_alert, - put_mock_creds, - remove_temp_secrets ) @@ -42,22 +33,17 @@ class TestSlackOutput(object): OUTPUT = ':'.join([SERVICE, DESCRIPTOR]) CREDS = {'url': 'https://api.slack.com/web-hook-key'} - @patch.dict('os.environ', MOCK_ENV) - def setup(self): + @patch('stream_alert.alert_processor.outputs.output_base.OutputCredentialsProvider') + def setup(self, provider_constructor): """Setup before each method""" - self._mock_s3 = mock_s3() - self._mock_s3.start() - self._mock_kms = mock_kms() - self._mock_kms.start() + provider = MagicMock() + provider_constructor.return_value = provider + provider.load_credentials = Mock( + side_effect=lambda x: self.CREDS if x == self.DESCRIPTOR else None + ) + + self._provider = provider self._dispatcher = SlackOutput(None) - remove_temp_secrets() - output_name = self._dispatcher.output_cred_name(self.DESCRIPTOR) - put_mock_creds(output_name, self.CREDS, self._dispatcher.secrets_bucket, REGION, KMS_ALIAS) - - def teardown(self): - """Teardown after each method""" - self._mock_s3.stop() - self._mock_kms.stop() def test_format_message_single(self): """SlackOutput - Format Single Message - Slack""" diff --git a/tests/unit/stream_alert_cli/test_outputs.py b/tests/unit/stream_alert_cli/test_outputs.py deleted file mode 100644 index a0c3f5053..000000000 --- a/tests/unit/stream_alert_cli/test_outputs.py +++ /dev/null @@ -1,82 +0,0 @@ -""" -Copyright 2017-present, Airbnb Inc. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -""" -import boto3 -from botocore.exceptions import ClientError -from mock import patch -from moto import mock_kms, mock_s3 -from nose.tools import assert_true, raises - -from stream_alert.alert_processor.outputs.output_base import OutputProperty -from stream_alert_cli.outputs.helpers import encrypt_and_push_creds_to_s3 - - -@mock_kms -@mock_s3 -@patch('stream_alert_cli.outputs.helpers.send_creds_to_s3') -def test_encrypt_and_push_creds_to_s3(send_mock): - """CLI - Outputs - Encrypt and push creds to s3""" - props = { - 'non-secret': OutputProperty( - description='short description of info needed', - value='http://this.url.value' - ) - } - - return_value = encrypt_and_push_creds_to_s3('us-east-1', 'bucket', 'key', props, 'test_alias') - - assert_true(return_value) - send_mock.assert_not_called() - - props['secret'] = OutputProperty( - description='short description of secret needed', - value='1908AGSG98A8908AG', - cred_requirement=True - ) - - # Create the bucket to hold the mock object being put - boto3.client('s3', region_name='us-east-1').create_bucket(Bucket='bucket') - - return_value = encrypt_and_push_creds_to_s3('us-east-1', 'bucket', 'key', props, 'test_alias') - - assert_true(return_value) - send_mock.assert_called() - - -@raises(ClientError) -@patch('boto3.client') -@patch('logging.Logger.error') -def test_encrypt_and_push_creds_to_s3_kms_failure(log_mock, boto_mock): - """CLI - Outputs - Encrypt and push creds to s3 - kms failure""" - props = { - 'secret': OutputProperty( - description='short description of secret needed', - value='1908AGSG98A8908AG', - cred_requirement=True)} - - err_response = { - 'Error': - { - 'Code': 100, - 'Message': 'BAAAD', - 'BucketName': 'bucket' - } - } - - # Add ClientError side_effect to mock - boto_mock.side_effect = ClientError(err_response, 'operation') - encrypt_and_push_creds_to_s3('us-east-1', 'bucket', 'key', props, 'test_alias') - - log_mock.assert_called_with('An error occurred during credential encryption') diff --git a/tests/unit/stream_alert_shared/test_aws_api_client.py b/tests/unit/stream_alert_shared/test_aws_api_client.py new file mode 100644 index 000000000..6f7c1a88f --- /dev/null +++ b/tests/unit/stream_alert_shared/test_aws_api_client.py @@ -0,0 +1,84 @@ +""" +Copyright 2017-present, Airbnb Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" +import tempfile + +from botocore.exceptions import ClientError +from mock import patch +from moto import mock_kms, mock_s3 +from nose.tools import assert_equal, raises + +from stream_alert.shared.helpers.aws_api_client import AwsS3, AwsKms +from tests.unit.helpers.aws_mocks import put_mock_s3_object +from tests.unit.stream_alert_alert_processor import KMS_ALIAS, REGION + + +class TestAwsKms(object): + + @staticmethod + @mock_kms + def test_encrypt_decrypt(): + """AwsApiClient - AwsKms - encrypt/decrypt - Encrypt and push creds, then pull them down""" + secret = 'shhhhhh' + + ciphertext = AwsKms.encrypt(secret, region=REGION, key_alias=KMS_ALIAS) + response = AwsKms.decrypt(ciphertext, region=REGION) + + assert_equal(response, secret) + + @staticmethod + @raises(ClientError) + @patch('boto3.client') + def test_encrypt_kms_failure(boto_mock): + """AwsApiClient - AwsKms - Encrypt - KMS Failure""" + response = { + 'Error': { + 'ErrorCode': 400, + 'Message': "bad bucket" + } + } + boto_mock.side_effect = ClientError(response, 'operation') + AwsKms.encrypt('secret', region=REGION, key_alias=KMS_ALIAS) + + +class TestAwsS3(object): + + @staticmethod + @mock_s3 + def test_put_download(): + """AwsApiClient - AwsS3 - PutObject/Download - Upload then download object""" + payload = 'zzzzz' + bucket = 'bucket' + key = 'key' + + # Annoyingly, moto needs us to create the bucket first + # We put a random unrelated object into the bucket and this will set up the bucket for us + put_mock_s3_object(bucket, 'aaa', 'bbb', REGION) + + AwsS3.put_object(payload, bucket=bucket, key=key, region=REGION) + + with tempfile.SpooledTemporaryFile(0, 'a+b') as file_handle: + result = AwsS3.download_fileobj(file_handle, bucket=bucket, key=key, region=REGION) + + assert_equal(result, payload) + + @staticmethod + @raises(ClientError) + @mock_s3 + def test_put_object_s3_failure(): + """AwsApiClient - AwsS3 - PutObject - S3 Failure""" + + # S3 will automatically fail because the bucket has not been created yet. + AwsS3.put_object('zzzpayload', bucket='aaa', key='zzz', region=REGION)