diff --git a/docs/source/rules.rst b/docs/source/rules.rst index e9cbfc3f9..707cd1290 100644 --- a/docs/source/rules.rst +++ b/docs/source/rules.rst @@ -161,21 +161,20 @@ datatypes .. code-block:: python """These rules apply to several different log types, defined in conf/normalized_types.json""" - from rules.helpers.base import fetch_values_by_datatype from stream_alert.shared.rule import rule + from stream_alert.shared.normalize import Normalizer @rule(datatypes=['sourceAddress'], outputs=['aws-sns:my-topic']) def ip_watchlist_hit(record): """Source IP address matches watchlist.""" - return '127.0.0.1' in fetch_values_by_datatype(record, 'sourceAddress') - + return '127.0.0.1' in Normalizer.get_values_for_normalized_type(record, 'sourceAddress') @rule(datatypes=['command'], outputs=['aws-sns:my-topic']) def command_etc_shadow(record): """Command line arguments include /etc/shadow""" return any( '/etc/shadow' in cmd.lower() - for cmd in fetch_values_by_datatype(record, 'command') + for cmd in Normalizer.get_values_for_normalized_type(record, 'command') ) logs diff --git a/rules/community/mitre_attack/defense_evasion/multi/obfuscated_files_or_information/right_to_left_character.py b/rules/community/mitre_attack/defense_evasion/multi/obfuscated_files_or_information/right_to_left_character.py index 57d4bde84..e444a37fa 100644 --- a/rules/community/mitre_attack/defense_evasion/multi/obfuscated_files_or_information/right_to_left_character.py +++ b/rules/community/mitre_attack/defense_evasion/multi/obfuscated_files_or_information/right_to_left_character.py @@ -1,6 +1,6 @@ """Detection of the right to left override unicode character U+202E in filename or process name.""" -from rules.helpers.base import fetch_values_by_datatype from stream_alert.shared.rule import rule +from stream_alert.shared.normalize import Normalizer @rule(datatypes=['command', 'filePath', 'processPath', 'fileName']) @@ -22,22 +22,22 @@ def right_to_left_character(rec): # Unicode character U+202E, right-to-left-override (RLO) rlo = u'\u202e' - commands = fetch_values_by_datatype(rec, 'command') + commands = Normalizer.get_values_for_normalized_type(rec, 'command') for command in commands: if isinstance(command, unicode) and rlo in command: return True - file_paths = fetch_values_by_datatype(rec, 'filePath') + file_paths = Normalizer.get_values_for_normalized_type(rec, 'filePath') for file_path in file_paths: if isinstance(file_path, unicode) and rlo in file_path: return True - process_paths = fetch_values_by_datatype(rec, 'processPath') + process_paths = Normalizer.get_values_for_normalized_type(rec, 'processPath') for process_path in process_paths: if isinstance(process_path, unicode) and rlo in process_path: return True - file_names = fetch_values_by_datatype(rec, 'fileName') + file_names = Normalizer.get_values_for_normalized_type(rec, 'fileName') for file_name in file_names: if isinstance(file_name, unicode) and rlo in file_name: return True diff --git a/rules/helpers/base.py b/rules/helpers/base.py index 00ec19f55..795c380f4 100644 --- a/rules/helpers/base.py +++ b/rules/helpers/base.py @@ -20,7 +20,6 @@ import time import pathlib2 -from stream_alert.shared import NORMALIZATION_KEY from stream_alert.shared.utils import ( # pylint: disable=unused-import # Import some utility functions which are useful for rules as well get_first_key, @@ -142,33 +141,6 @@ def last_hour(unixtime, hours=1): return int(time.time()) - int(unixtime) <= seconds if unixtime else False -def fetch_values_by_datatype(rec, datatype): - """Fetch values of normalized_type. - - Args: - rec (dict): parsed payload of any log - datatype (str): normalized type user interested - - Returns: - (list) The values of normalized types - """ - results = [] - if not rec.get(NORMALIZATION_KEY): - return results - - if datatype not in rec[NORMALIZATION_KEY]: - return results - - for original_keys in rec[NORMALIZATION_KEY][datatype]: - result = rec - if isinstance(original_keys, list): - for original_key in original_keys: - result = result[original_key] - results.append(result) - - return results - - def data_has_value(data, search_value): """Recursively search for a given value. diff --git a/stream_alert/alert_processor/main.py b/stream_alert/alert_processor/main.py index 11ee4bcb4..17be31431 100644 --- a/stream_alert/alert_processor/main.py +++ b/stream_alert/alert_processor/main.py @@ -20,11 +20,12 @@ from botocore.exceptions import ClientError from stream_alert.alert_processor.outputs.output_base import StreamAlertOutput -from stream_alert.shared import backoff_handlers, NORMALIZATION_KEY, resources +from stream_alert.shared import backoff_handlers, resources from stream_alert.shared.alert import Alert, AlertCreationError from stream_alert.shared.alert_table import AlertTable from stream_alert.shared.config import load_config from stream_alert.shared.logger import get_logger +from stream_alert.shared.normalize import Normalizer LOGGER = get_logger(__name__) @@ -143,8 +144,8 @@ def run(self, event): # Remove normalization key from the record. # TODO: Consider including this in at least some outputs, e.g. default Athena firehose - if NORMALIZATION_KEY in alert.record: - del alert.record[NORMALIZATION_KEY] + if Normalizer.NORMALIZATION_KEY in alert.record: + del alert.record[Normalizer.NORMALIZATION_KEY] result = self._send_to_outputs(alert) self._update_table(alert, result) diff --git a/stream_alert/classifier/classifier.py b/stream_alert/classifier/classifier.py index ac97a3215..f9cbe7441 100644 --- a/stream_alert/classifier/classifier.py +++ b/stream_alert/classifier/classifier.py @@ -17,12 +17,12 @@ import logging from stream_alert.classifier.clients import FirehoseClient, SQSClient -from stream_alert.classifier.normalize import NORMALIZATION_KEY, Normalizer from stream_alert.classifier.parsers import get_parser from stream_alert.classifier.payload.payload_base import StreamPayload from stream_alert.shared import config, CLASSIFIER_FUNCTION_NAME as FUNCTION_NAME from stream_alert.shared.logger import get_logger from stream_alert.shared.metrics import MetricLogger +from stream_alert.shared.normalize import Normalizer LOGGER = get_logger(__name__) @@ -212,7 +212,7 @@ def _log_metrics(self): MetricLogger.NORMALIZED_RECORDS, sum( 1 for payload in self._payloads - for log in payload.parsed_records if log.get(NORMALIZATION_KEY) + for log in payload.parsed_records if log.get(Normalizer.NORMALIZATION_KEY) ) ) MetricLogger.log_metric( diff --git a/stream_alert/rules_engine/threat_intel.py b/stream_alert/rules_engine/threat_intel.py index dbcc6a454..770fcefb2 100644 --- a/stream_alert/rules_engine/threat_intel.py +++ b/stream_alert/rules_engine/threat_intel.py @@ -21,13 +21,13 @@ from boto3.dynamodb.types import TypeDeserializer from botocore.exceptions import ClientError, ParamValidationError from netaddr import IPNetwork -from stream_alert.shared import NORMALIZATION_KEY from stream_alert.shared.backoff_handlers import ( backoff_handler, success_handler, giveup_handler ) from stream_alert.shared.logger import get_logger +from stream_alert.shared.normalize import Normalizer from stream_alert.shared.utils import in_network, valid_ip @@ -288,27 +288,6 @@ def _is_excluded_ioc(self, ioc_type, ioc_value): return ioc_value in exclusions - @classmethod - def _extract_values_by_keys(cls, record, key_values): - """Return a value from the record given its path of Keys - - Args: - record (dict): Record from which to extract values - key_values (list): List of lists with keys in path to values - """ - values = set() - for original_keys in key_values: - value = record - for original_key in original_keys: - value = value[original_key] - - if not value: # ensure the value is not falsy/empty - continue - - values.add(str(value).lower()) - - return values - def _extract_ioc_values(self, payloads): """Instance method to extract IOC info from the record based on normalized keys @@ -317,22 +296,23 @@ def _extract_ioc_values(self, payloads): normalized data Returns: - list: Return a list of RecordIOC instances. + dict: Map of ioc values to the source record and type of ioc """ ioc_values = defaultdict(list) for payload in payloads: record = payload['record'] - if NORMALIZATION_KEY not in record: + if Normalizer.NORMALIZATION_KEY not in record: continue - for normalized_key, original_key_values in record[NORMALIZATION_KEY].iteritems(): - # Lookup mapped IOC type based on normalized CEF type + normalized_values = record[Normalizer.NORMALIZATION_KEY] + for normalized_key, values in normalized_values.iteritems(): + # Look up mapped IOC type based on normalized CEF type ioc_type = self._ioc_config.get(normalized_key) if not ioc_type: LOGGER.debug('Skipping undefined IOC type for normalized key: %s', normalized_key) continue - for value in self._extract_values_by_keys(record, original_key_values): + for value in values: # Skip excluded IOCs if self._is_excluded_ioc(ioc_type, value): continue diff --git a/stream_alert/shared/__init__.py b/stream_alert/shared/__init__.py index 9390abad5..d079ce690 100644 --- a/stream_alert/shared/__init__.py +++ b/stream_alert/shared/__init__.py @@ -5,6 +5,5 @@ CLASSIFIER_FUNCTION_NAME = 'classifier' RULES_ENGINE_FUNCTION_NAME = 'rules_engine' RULE_PROMOTION_NAME = 'rule_promotion' -NORMALIZATION_KEY = 'streamalert:normalization' CLUSTERED_FUNCTIONS = {CLASSIFIER_FUNCTION_NAME} diff --git a/stream_alert/classifier/normalize.py b/stream_alert/shared/normalize.py similarity index 70% rename from stream_alert/classifier/normalize.py rename to stream_alert/shared/normalize.py index 36a52e24e..c1ebe30b6 100644 --- a/stream_alert/classifier/normalize.py +++ b/stream_alert/shared/normalize.py @@ -15,7 +15,6 @@ """ import logging -from stream_alert.shared import NORMALIZATION_KEY from stream_alert.shared.config import TopLevelConfigKeys from stream_alert.shared.logger import get_logger @@ -27,6 +26,8 @@ class Normalizer(object): """Normalizer class to handle log key normalization in payloads""" + NORMALIZATION_KEY = 'streamalert:normalization' + # Store the normalized CEF types mapping to original keys from the records _types_config = dict() @@ -39,13 +40,13 @@ def match_types(cls, record, normalized_types): normalized_types (dict): Normalized types mapping Returns: - dict: A dict of normalized_types with original key names + dict: A dict of normalized keys with a list of values Example: record={ - 'region': 'region_name', + 'region': 'us-east-1', 'detail': { - 'awsRegion': 'region_name' + 'awsRegion': 'us-west-2' } } normalized_types={ @@ -53,16 +54,16 @@ def match_types(cls, record, normalized_types): } return={ - 'region': [['region'], ['detail', 'awsRegion']] + 'region': ['us-east-1', 'us-west-2'] } """ return { - key: list(cls._extract_paths(record, keys_to_normalize)) + key: set(cls._extract_values(record, set(keys_to_normalize))) for key, keys_to_normalize in normalized_types.iteritems() } @classmethod - def _extract_paths(cls, record, keys_to_normalize, path=None): + def _extract_values(cls, record, keys_to_normalize): """Recursively extract lists of path parts from a dictionary Args: @@ -73,17 +74,19 @@ def _extract_paths(cls, record, keys_to_normalize, path=None): Yields: list: Parts of path in dictionary that contain normalized keys """ - # Cast the JSON array to a set for quicker lookups - keys_to_normalize = set(keys_to_normalize) - path = path or [] for key, value in record.iteritems(): - temp_path = [item for item in path] - temp_path.append(key) - if key in keys_to_normalize: - yield temp_path - if isinstance(value, dict): - for nested_path in cls._extract_paths(value, keys_to_normalize, temp_path): - yield nested_path + if isinstance(value, dict): # If this is a dict, look for nested + for nested_value in cls._extract_values(value, keys_to_normalize): + yield nested_value + + if key not in keys_to_normalize: + continue + + if isinstance(value, list): # If this is a list of values, return all of them + for item in value: + yield item + else: + yield value @classmethod def normalize(cls, record, log_type): @@ -99,7 +102,20 @@ def normalize(cls, record, log_type): return # Add normalized keys to the record - record.update({NORMALIZATION_KEY: cls.match_types(record, log_normalized_types)}) + record.update({cls.NORMALIZATION_KEY: cls.match_types(record, log_normalized_types)}) + + @classmethod + def get_values_for_normalized_type(cls, record, datatype): + """Fetch values by normalized_type. + + Args: + record (dict): parsed payload of any log + datatype (str): normalized type being found + + Returns: + set: The values for the normalized type specified + """ + return set(record.get(cls.NORMALIZATION_KEY, {}).get(datatype, set())) @classmethod def load_from_config(cls, config): diff --git a/stream_alert/shared/utils.py b/stream_alert/shared/utils.py index 35d42286c..e8c130eb2 100644 --- a/stream_alert/shared/utils.py +++ b/stream_alert/shared/utils.py @@ -4,8 +4,8 @@ from netaddr import IPAddress, IPNetwork from netaddr.core import AddrFormatError -from stream_alert.shared import NORMALIZATION_KEY from stream_alert.shared.logger import get_logger +from stream_alert.shared.normalize import Normalizer LOGGER = get_logger(__name__) @@ -133,7 +133,7 @@ def get_keys(data, search_key, max_matches=-1): # helper may fetch info from normalization if there are keyname conflict. # For example, Key name 'userName' is both existed as a normalized key defined # in conf/normalized_types.json and cloudtrail record schemas. - if key == NORMALIZATION_KEY: + if key == Normalizer.NORMALIZATION_KEY: continue if val and isinstance(val, _CONTAINER_TYPES): containers.append(val) diff --git a/tests/unit/stream_alert_alert_processor/test_main.py b/tests/unit/stream_alert_alert_processor/test_main.py index 9ce843097..200cf5d92 100644 --- a/tests/unit/stream_alert_alert_processor/test_main.py +++ b/tests/unit/stream_alert_alert_processor/test_main.py @@ -24,9 +24,9 @@ from stream_alert.alert_processor.main import AlertProcessor, handler from stream_alert.alert_processor.outputs.output_base import OutputDispatcher -from stream_alert.shared import NORMALIZATION_KEY from stream_alert.shared.alert import Alert from stream_alert.shared.config import load_config +from stream_alert.shared.normalize import Normalizer from tests.unit.stream_alert_alert_processor import ( ALERTS_TABLE, MOCK_ENV @@ -52,7 +52,7 @@ def setup(self): self.processor = AlertProcessor() self.alert = Alert( 'hello_world', - {'abc': 123, NORMALIZATION_KEY: {}}, + {'abc': 123, Normalizer.NORMALIZATION_KEY: {}}, {'slack:unit-test-channel'} ) diff --git a/tests/unit/streamalert/classifier/test_normalizer.py b/tests/unit/stream_alert_shared/test_normalizer.py similarity index 74% rename from tests/unit/streamalert/classifier/test_normalizer.py rename to tests/unit/stream_alert_shared/test_normalizer.py index d3eee9f23..a02f14b1d 100644 --- a/tests/unit/streamalert/classifier/test_normalizer.py +++ b/tests/unit/stream_alert_shared/test_normalizer.py @@ -16,7 +16,7 @@ from mock import patch from nose.tools import assert_equal -from stream_alert.classifier.normalize import Normalizer +from stream_alert.shared.normalize import Normalizer class TestNormalizer(object): @@ -35,11 +35,11 @@ def _test_record(cls): 'awsRegion': 'region_name', 'source': '1.1.1.2', 'userIdentity': { - "userName": "Alice", - "invokedBy": "signin.amazonaws.com" + 'userName': 'Alice', + 'invokedBy': 'signin.amazonaws.com' } }, - 'sourceIPAddress': '1.1.1.2' + 'sourceIPAddress': '1.1.1.3' } def test_match_types(self): @@ -50,9 +50,9 @@ def test_match_types(self): 'ipv4': ['destination', 'source', 'sourceIPAddress'] } expected_results = { - 'sourceAccount': [['account']], - 'ipv4': [['sourceIPAddress'], ['detail', 'source']], - 'region': [['region'], ['detail', 'awsRegion']] + 'sourceAccount': {123456}, + 'ipv4': {'1.1.1.2', '1.1.1.3'}, + 'region': {'region_name'} } results = Normalizer.match_types(self._test_record(), normalized_types) @@ -67,18 +67,32 @@ def test_match_types_multiple(self): 'userName': ['userName', 'owner', 'invokedBy'] } expected_results = { - 'account': [['account']], - 'ipv4': [['sourceIPAddress'], ['detail', 'source']], - 'region': [['region'], ['detail', 'awsRegion']], - 'userName': [ - ['detail', 'userIdentity', 'userName'], - ['detail', 'userIdentity', 'invokedBy'] - ] + 'account': {123456}, + 'ipv4': {'1.1.1.2', '1.1.1.3'}, + 'region': {'region_name'}, + 'userName': {'Alice', 'signin.amazonaws.com'} } results = Normalizer.match_types(self._test_record(), normalized_types) assert_equal(results, expected_results) + def test_match_types_list(self): + """Normalizer - Match Types, List of Values""" + normalized_types = { + 'ipv4': ['sourceIPAddress'], + } + expected_results = { + 'ipv4': {'1.1.1.2', '1.1.1.3'} + } + + test_record = { + 'account': 123456, + 'sourceIPAddress': ['1.1.1.2', '1.1.1.3'] + } + + results = Normalizer.match_types(test_record, normalized_types) + assert_equal(results, expected_results) + def test_normalize(self): """Normalizer - Normalize""" log_type = 'cloudtrail' @@ -108,10 +122,10 @@ def test_normalize(self): "invokedBy": "signin.amazonaws.com" } }, - 'sourceIPAddress': '1.1.1.2', + 'sourceIPAddress': '1.1.1.3', 'streamalert:normalization': { - 'region': [['region'], ['detail', 'awsRegion']], - 'sourceAccount': [['account']] + 'region': {'region_name'}, + 'sourceAccount': {123456} } } @@ -148,9 +162,9 @@ def test_normalize_bad_normalized_key(self): "invokedBy": "signin.amazonaws.com" } }, - 'sourceIPAddress': '1.1.1.2', + 'sourceIPAddress': '1.1.1.3', 'streamalert:normalization': { - 'bad_type': [], + 'bad_type': set(), } } @@ -158,6 +172,27 @@ def test_normalize_bad_normalized_key(self): Normalizer.normalize(record, log_type) assert_equal(record, expected_record) + def test_get_values_for_normalized_type(self): + """Normalizer - Get Values for Normalized Type""" + expected_result = {'1.1.1.3'} + record = { + 'sourceIPAddress': '1.1.1.3', + 'streamalert:normalization': { + 'ip_v4': expected_result, + } + } + + assert_equal(Normalizer.get_values_for_normalized_type(record, 'ip_v4'), expected_result) + + def test_get_values_for_normalized_type_none(self): + """Normalizer - Get Values for Normalized Type, None""" + record = { + 'sourceIPAddress': '1.1.1.3', + 'streamalert:normalization': {} + } + + assert_equal(Normalizer.get_values_for_normalized_type(record, 'ip_v4'), set()) + def test_load_from_config(self): """Normalizer - Load From Config""" config = { diff --git a/tests/unit/streamalert/rules_engine/test_threat_intel.py b/tests/unit/streamalert/rules_engine/test_threat_intel.py index d5545c970..cb37d315d 100644 --- a/tests/unit/streamalert/rules_engine/test_threat_intel.py +++ b/tests/unit/streamalert/rules_engine/test_threat_intel.py @@ -90,8 +90,8 @@ def _sample_payload(self): }, 'source': '1.1.1.2', 'streamalert:normalization': { - 'sourceAddress': [['detail', 'sourceIPAddress'], ['source']], - 'userName': [['detail', 'userIdentity', 'userName']] + 'sourceAddress': {'1.1.1.2'}, + 'userName': {'alice'} } } } @@ -123,8 +123,8 @@ def test_threat_detection(self): }, 'source': '1.1.1.2', 'streamalert:normalization': { - 'sourceAddress': [['detail', 'sourceIPAddress'], ['source']], - 'userName': [['detail', 'userIdentity', 'userName']] + 'sourceAddress': {'1.1.1.2'}, + 'userName': {'alice'} }, 'streamalert:ioc': { 'ip': {'1.1.1.2'} @@ -410,26 +410,6 @@ def test_is_excluded_ioc_ip(self): assert_equal(self._threat_intel._is_excluded_ioc('ip', '1.2.3.20'), False) assert_equal(self._threat_intel._is_excluded_ioc('ip', '1.2.3.15'), True) - def test_extract_values_by_keys(self): - """ThreatIntel - Extract Values By Keys""" - record = { - 'region': 'us-east-1', - 'detail': { - 'eventName': 'ConsoleLogin', - 'sourceIPAddress': None - }, - 'source': '1.1.1.2' - } - - keys = [['detail', 'sourceIPAddress'], ['source']] - - expected_result = [ - '1.1.1.2' - ] - - result = list(ThreatIntel._extract_values_by_keys(record, keys)) - assert_equal(result, expected_result) - def test_extract_ioc_values(self): """ThreatIntel - Extract IOC Values""" payloads = [self._sample_payload]