update normalization logic to store raw values instead of paths to values in json (#854)

* moving the normalization logic to shared

* updating unit tests for Normalizer class

* updating community rule for new normalization usage

* updating imports for new normalization key location

* removing the rule helper that is now part of the Normalizer class

* ThreatIntel updates to support new normalized value lookup

* updates to documentation
ryandeivert committed Dec 18, 2018
1 parent 46724c9 commit 2f79eea
Showing 12 changed files with 117 additions and 135 deletions.
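
The heart of the change: the streamalert:normalization key embedded in each record now holds the extracted values themselves instead of lists of key paths pointing back into the record. A minimal sketch of the shape change, with hypothetical record fields and IP values (the new match_types builds sets; the old shape stored paths as JSON arrays):

# Hypothetical record, for illustration only
record = {
    'sourceIPAddress': '10.0.0.1',
    'detail': {'sourceIPAddress': '192.168.1.1'},
}

# Before this commit: paths to the values within the record
record['streamalert:normalization'] = {
    'sourceAddress': [['sourceIPAddress'], ['detail', 'sourceIPAddress']]
}

# After this commit: the raw values, ready for direct lookup
record['streamalert:normalization'] = {
    'sourceAddress': {'10.0.0.1', '192.168.1.1'}
}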
7 changes: 3 additions & 4 deletions docs/source/rules.rst
@@ -161,21 +161,20 @@ datatypes
.. code-block:: python

"""These rules apply to several different log types, defined in conf/normalized_types.json"""
-from rules.helpers.base import fetch_values_by_datatype
from stream_alert.shared.rule import rule
+from stream_alert.shared.normalize import Normalizer

@rule(datatypes=['sourceAddress'], outputs=['aws-sns:my-topic'])
def ip_watchlist_hit(record):
    """Source IP address matches watchlist."""
-    return '127.0.0.1' in fetch_values_by_datatype(record, 'sourceAddress')
+    return '127.0.0.1' in Normalizer.get_values_for_normalized_type(record, 'sourceAddress')

@rule(datatypes=['command'], outputs=['aws-sns:my-topic'])
def command_etc_shadow(record):
    """Command line arguments include /etc/shadow"""
    return any(
        '/etc/shadow' in cmd.lower()
-        for cmd in fetch_values_by_datatype(record, 'command')
+        for cmd in Normalizer.get_values_for_normalized_type(record, 'command')
    )
logs
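For context, the datatypes listed above come from conf/normalized_types.json, which maps each log's raw key names to normalized CEF type names. A hypothetical excerpt (the real file depends on your configured log schemas):

{
  "cloudwatch": {
    "sourceAddress": ["sourceIPAddress"],
    "command": ["cmdline"]
  }
}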
@@ -1,6 +1,6 @@
"""Detection of the right to left override unicode character U+202E in filename or process name."""
-from rules.helpers.base import fetch_values_by_datatype
from stream_alert.shared.rule import rule
+from stream_alert.shared.normalize import Normalizer


@rule(datatypes=['command', 'filePath', 'processPath', 'fileName'])
@@ -22,22 +22,22 @@ def right_to_left_character(rec):
    # Unicode character U+202E, right-to-left-override (RLO)
    rlo = u'\u202e'

-    commands = fetch_values_by_datatype(rec, 'command')
+    commands = Normalizer.get_values_for_normalized_type(rec, 'command')
    for command in commands:
        if isinstance(command, unicode) and rlo in command:
            return True

-    file_paths = fetch_values_by_datatype(rec, 'filePath')
+    file_paths = Normalizer.get_values_for_normalized_type(rec, 'filePath')
    for file_path in file_paths:
        if isinstance(file_path, unicode) and rlo in file_path:
            return True

-    process_paths = fetch_values_by_datatype(rec, 'processPath')
+    process_paths = Normalizer.get_values_for_normalized_type(rec, 'processPath')
    for process_path in process_paths:
        if isinstance(process_path, unicode) and rlo in process_path:
            return True

-    file_names = fetch_values_by_datatype(rec, 'fileName')
+    file_names = Normalizer.get_values_for_normalized_type(rec, 'fileName')
    for file_name in file_names:
        if isinstance(file_name, unicode) and rlo in file_name:
            return True
28 changes: 0 additions & 28 deletions rules/helpers/base.py
@@ -20,7 +20,6 @@
import time
import pathlib2

-from stream_alert.shared import NORMALIZATION_KEY
from stream_alert.shared.utils import ( # pylint: disable=unused-import
    # Import some utility functions which are useful for rules as well
    get_first_key,
@@ -142,33 +141,6 @@ def last_hour(unixtime, hours=1):
    return int(time.time()) - int(unixtime) <= seconds if unixtime else False


-def fetch_values_by_datatype(rec, datatype):
-    """Fetch values of normalized_type.
-
-    Args:
-        rec (dict): parsed payload of any log
-        datatype (str): normalized type user interested
-
-    Returns:
-        (list) The values of normalized types
-    """
-    results = []
-    if not rec.get(NORMALIZATION_KEY):
-        return results
-
-    if datatype not in rec[NORMALIZATION_KEY]:
-        return results
-
-    for original_keys in rec[NORMALIZATION_KEY][datatype]:
-        result = rec
-        if isinstance(original_keys, list):
-            for original_key in original_keys:
-                result = result[original_key]
-        results.append(result)
-
-    return results


def data_has_value(data, search_value):
    """Recursively search for a given value.
7 changes: 4 additions & 3 deletions stream_alert/alert_processor/main.py
@@ -20,11 +20,12 @@
from botocore.exceptions import ClientError

from stream_alert.alert_processor.outputs.output_base import StreamAlertOutput
-from stream_alert.shared import backoff_handlers, NORMALIZATION_KEY, resources
+from stream_alert.shared import backoff_handlers, resources
from stream_alert.shared.alert import Alert, AlertCreationError
from stream_alert.shared.alert_table import AlertTable
from stream_alert.shared.config import load_config
from stream_alert.shared.logger import get_logger
+from stream_alert.shared.normalize import Normalizer

LOGGER = get_logger(__name__)

@@ -143,8 +144,8 @@ def run(self, event):

        # Remove normalization key from the record.
        # TODO: Consider including this in at least some outputs, e.g. default Athena firehose
-        if NORMALIZATION_KEY in alert.record:
-            del alert.record[NORMALIZATION_KEY]
+        if Normalizer.NORMALIZATION_KEY in alert.record:
+            del alert.record[Normalizer.NORMALIZATION_KEY]

        result = self._send_to_outputs(alert)
        self._update_table(alert, result)
4 changes: 2 additions & 2 deletions stream_alert/classifier/classifier.py
@@ -17,12 +17,12 @@
import logging

from stream_alert.classifier.clients import FirehoseClient, SQSClient
-from stream_alert.classifier.normalize import NORMALIZATION_KEY, Normalizer
from stream_alert.classifier.parsers import get_parser
from stream_alert.classifier.payload.payload_base import StreamPayload
from stream_alert.shared import config, CLASSIFIER_FUNCTION_NAME as FUNCTION_NAME
from stream_alert.shared.logger import get_logger
from stream_alert.shared.metrics import MetricLogger
+from stream_alert.shared.normalize import Normalizer


LOGGER = get_logger(__name__)
@@ -212,7 +212,7 @@ def _log_metrics(self):
            MetricLogger.NORMALIZED_RECORDS,
            sum(
                1 for payload in self._payloads
-                for log in payload.parsed_records if log.get(NORMALIZATION_KEY)
+                for log in payload.parsed_records if log.get(Normalizer.NORMALIZATION_KEY)
            )
        )
        MetricLogger.log_metric(
34 changes: 7 additions & 27 deletions stream_alert/rules_engine/threat_intel.py
@@ -21,13 +21,13 @@
from boto3.dynamodb.types import TypeDeserializer
from botocore.exceptions import ClientError, ParamValidationError
from netaddr import IPNetwork
-from stream_alert.shared import NORMALIZATION_KEY
from stream_alert.shared.backoff_handlers import (
    backoff_handler,
    success_handler,
    giveup_handler
)
from stream_alert.shared.logger import get_logger
+from stream_alert.shared.normalize import Normalizer
from stream_alert.shared.utils import in_network, valid_ip


@@ -288,27 +288,6 @@ def _is_excluded_ioc(self, ioc_type, ioc_value):

        return ioc_value in exclusions

-    @classmethod
-    def _extract_values_by_keys(cls, record, key_values):
-        """Return a value from the record given its path of Keys
-
-        Args:
-            record (dict): Record from which to extract values
-            key_values (list<list>): List of lists with keys in path to values
-        """
-        values = set()
-        for original_keys in key_values:
-            value = record
-            for original_key in original_keys:
-                value = value[original_key]
-
-            if not value:  # ensure the value is not falsy/empty
-                continue
-
-            values.add(str(value).lower())
-
-        return values

    def _extract_ioc_values(self, payloads):
        """Instance method to extract IOC info from the record based on normalized keys
@@ -317,22 +296,23 @@ def _extract_ioc_values(self, payloads):
                normalized data

        Returns:
-            list: Return a list of RecordIOC instances.
+            dict: Map of ioc values to the source record and type of ioc
        """
        ioc_values = defaultdict(list)
        for payload in payloads:
            record = payload['record']
-            if NORMALIZATION_KEY not in record:
+            if Normalizer.NORMALIZATION_KEY not in record:
                continue
-            for normalized_key, original_key_values in record[NORMALIZATION_KEY].iteritems():
-                # Lookup mapped IOC type based on normalized CEF type
+            normalized_values = record[Normalizer.NORMALIZATION_KEY]
+            for normalized_key, values in normalized_values.iteritems():
+                # Look up mapped IOC type based on normalized CEF type
                ioc_type = self._ioc_config.get(normalized_key)
                if not ioc_type:
                    LOGGER.debug('Skipping undefined IOC type for normalized key: %s',
                                 normalized_key)
                    continue

-                for value in self._extract_values_by_keys(record, original_key_values):
+                for value in values:
                    # Skip excluded IOCs
                    if self._is_excluded_ioc(ioc_type, value):
                        continue
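Because values now live directly under the normalization key, IOC extraction reduces to a plain dictionary walk. A minimal sketch of the lookup _extract_ioc_values now performs, with a hypothetical record and IOC config (Python 2, matching the codebase):

ioc_config = {'sourceAddress': 'ip'}  # hypothetical normalized-type-to-IOC-type mapping

record = {
    'streamalert:normalization': {
        'sourceAddress': {'10.0.0.1', '192.168.1.1'}
    }
}

# Values come straight from the record; no path traversal required
for normalized_key, values in record['streamalert:normalization'].iteritems():
    ioc_type = ioc_config.get(normalized_key)
    if not ioc_type:
        continue  # skip normalized types with no IOC mapping
    for value in values:
        print('%s: %s' % (ioc_type, value))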
1 change: 0 additions & 1 deletion stream_alert/shared/__init__.py
@@ -5,6 +5,5 @@
CLASSIFIER_FUNCTION_NAME = 'classifier'
RULES_ENGINE_FUNCTION_NAME = 'rules_engine'
RULE_PROMOTION_NAME = 'rule_promotion'
-NORMALIZATION_KEY = 'streamalert:normalization'

CLUSTERED_FUNCTIONS = {CLASSIFIER_FUNCTION_NAME}
stream_alert/shared/normalize.py
@@ -15,7 +15,6 @@
"""
import logging

-from stream_alert.shared import NORMALIZATION_KEY
from stream_alert.shared.config import TopLevelConfigKeys
from stream_alert.shared.logger import get_logger

@@ -27,6 +26,8 @@
class Normalizer(object):
"""Normalizer class to handle log key normalization in payloads"""

+    NORMALIZATION_KEY = 'streamalert:normalization'

    # Store the normalized CEF types mapping to original keys from the records
    _types_config = dict()

@@ -39,30 +40,30 @@ def match_types(cls, record, normalized_types):
            normalized_types (dict): Normalized types mapping

        Returns:
-            dict: A dict of normalized_types with original key names
+            dict: A dict of normalized keys with a list of values

        Example:
            record={
-                'region': 'region_name',
+                'region': 'us-east-1',
                'detail': {
-                    'awsRegion': 'region_name'
+                    'awsRegion': 'us-west-2'
                }
            }
            normalized_types={
                'region': ['region', 'awsRegion']
            }
            return={
-                'region': [['region'], ['detail', 'awsRegion']]
+                'region': ['us-east-1', 'us-west-2']
            }
        """
        return {
-            key: list(cls._extract_paths(record, keys_to_normalize))
+            key: set(cls._extract_values(record, set(keys_to_normalize)))
            for key, keys_to_normalize in normalized_types.iteritems()
        }

    @classmethod
-    def _extract_paths(cls, record, keys_to_normalize, path=None):
+    def _extract_values(cls, record, keys_to_normalize):
        """Recursively extract lists of path parts from a dictionary

        Args:
@@ -73,17 +74,19 @@ def _extract_paths(cls, record, keys_to_normalize, path=None):
        Yields:
            list: Parts of path in dictionary that contain normalized keys
        """
-        # Cast the JSON array to a set for quicker lookups
-        keys_to_normalize = set(keys_to_normalize)
-        path = path or []
        for key, value in record.iteritems():
-            temp_path = [item for item in path]
-            temp_path.append(key)
-            if key in keys_to_normalize:
-                yield temp_path
-            if isinstance(value, dict):
-                for nested_path in cls._extract_paths(value, keys_to_normalize, temp_path):
-                    yield nested_path
+            if isinstance(value, dict):  # If this is a dict, look for nested
+                for nested_value in cls._extract_values(value, keys_to_normalize):
+                    yield nested_value
+
+            if key not in keys_to_normalize:
+                continue
+
+            if isinstance(value, list):  # If this is a list of values, return all of them
+                for item in value:
+                    yield item
+            else:
+                yield value

    @classmethod
    def normalize(cls, record, log_type):
@@ -99,7 +102,20 @@ def normalize(cls, record, log_type):
            return

        # Add normalized keys to the record
-        record.update({NORMALIZATION_KEY: cls.match_types(record, log_normalized_types)})
+        record.update({cls.NORMALIZATION_KEY: cls.match_types(record, log_normalized_types)})

+    @classmethod
+    def get_values_for_normalized_type(cls, record, datatype):
+        """Fetch values by normalized_type.
+
+        Args:
+            record (dict): parsed payload of any log
+            datatype (str): normalized type being found
+
+        Returns:
+            set: The values for the normalized type specified
+        """
+        return set(record.get(cls.NORMALIZATION_KEY, {}).get(datatype, set()))

    @classmethod
    def load_from_config(cls, config):
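Taken together, a hypothetical end-to-end use of the reworked class (assuming stream_alert is importable; shapes follow the docstrings above):

from stream_alert.shared.normalize import Normalizer

record = {'region': 'us-east-1', 'detail': {'awsRegion': 'us-west-2'}}

# match_types can be exercised directly with an inline mapping
matches = Normalizer.match_types(record, {'region': ['region', 'awsRegion']})
assert matches == {'region': {'us-east-1', 'us-west-2'}}

# normalize() stores the same mapping on the record itself, after which
# rules can fetch values without knowing where they sat in the schema
record[Normalizer.NORMALIZATION_KEY] = matches
assert Normalizer.get_values_for_normalized_type(record, 'region') == {'us-east-1', 'us-west-2'}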
4 changes: 2 additions & 2 deletions stream_alert/shared/utils.py
@@ -4,8 +4,8 @@
from netaddr import IPAddress, IPNetwork
from netaddr.core import AddrFormatError

-from stream_alert.shared import NORMALIZATION_KEY
from stream_alert.shared.logger import get_logger
+from stream_alert.shared.normalize import Normalizer


LOGGER = get_logger(__name__)
@@ -133,7 +133,7 @@ def get_keys(data, search_key, max_matches=-1):
        # this helper may fetch info from normalization if there is a key name conflict.
        # For example, the key 'userName' exists both as a normalized key defined
        # in conf/normalized_types.json and in CloudTrail record schemas.
-        if key == NORMALIZATION_KEY:
+        if key == Normalizer.NORMALIZATION_KEY:
            continue
        if val and isinstance(val, _CONTAINER_TYPES):
            containers.append(val)
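The guard above keeps get_keys from descending into the normalization subtree, so a search key that doubles as a normalized type name only matches real record fields. A hedged sketch of the intended behavior (hypothetical record; assumes get_keys returns the matched values as a list):

from stream_alert.shared.normalize import Normalizer
from stream_alert.shared.utils import get_keys

record = {
    'userName': 'alice',
    Normalizer.NORMALIZATION_KEY: {'userName': {'alice'}},
}

# Only the real field is matched; the normalization subtree is skipped
assert get_keys(record, 'userName') == ['alice']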
4 changes: 2 additions & 2 deletions tests/unit/stream_alert_alert_processor/test_main.py
@@ -24,9 +24,9 @@

from stream_alert.alert_processor.main import AlertProcessor, handler
from stream_alert.alert_processor.outputs.output_base import OutputDispatcher
-from stream_alert.shared import NORMALIZATION_KEY
from stream_alert.shared.alert import Alert
from stream_alert.shared.config import load_config
+from stream_alert.shared.normalize import Normalizer
from tests.unit.stream_alert_alert_processor import (
    ALERTS_TABLE,
    MOCK_ENV
@@ -52,7 +52,7 @@ def setup(self):
        self.processor = AlertProcessor()
        self.alert = Alert(
            'hello_world',
-            {'abc': 123, NORMALIZATION_KEY: {}},
+            {'abc': 123, Normalizer.NORMALIZATION_KEY: {}},
            {'slack:unit-test-channel'}
        )

