Better linting and static type-checking
Austin Byers committed Apr 13, 2018
1 parent 2ca8316 commit 4d84078
Showing 16 changed files with 37 additions and 35 deletions.
2 changes: 1 addition & 1 deletion .pylintrc
@@ -239,7 +239,7 @@ ignore-docstrings=yes
ignore-imports=no

# Minimum lines number of a similarity.
-min-similarity-lines=4
+min-similarity-lines=7


[SPELLING]
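Note: min-similarity-lines is the threshold for pylint's duplicate-code check (R0801), which reports blocks of at least that many similar lines repeated across the analyzed files. Raising it from 4 to 7 stops short, unavoidably similar blocks (such as boilerplate event parsing) from being flagged. A hypothetical sketch of what the old threshold would catch if these two functions lived in separate analyzed modules:

    # Four identical lines: R0801 fires at min-similarity-lines=4, passes at 7.
    def handler_a(event):
        record = event['Records'][0]
        region = record.get('awsRegion', '')
        bucket = record['s3']['bucket']['name']
        key = record['s3']['object']['key']
        print(bucket, key, region)

    def handler_b(event):
        record = event['Records'][0]
        region = record.get('awsRegion', '')
        bucket = record['s3']['bucket']['name']
        key = record['s3']['object']['key']
        return bucket, key, region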
4 changes: 2 additions & 2 deletions .travis.yml
@@ -8,8 +8,8 @@ install:
script:
- coverage run manage.py unit_test
- coverage report # Required coverage threshold specified in .coveragerc
-- find . -name '*.py' -not -path './docs/source/*' -exec pylint '{}' + # Config in .pylintrc
-- mypy . --ignore-missing-imports
+- pylint lambda_functions rules tests *.py -j 1 # Config in .pylintrc
+- mypy lambda_functions rules *.py --disallow-untyped-defs --ignore-missing-imports --warn-unused-ignores
- bandit -r . # Configuration in .bandit
- sphinx-build -W docs/source docs/build
after_success:
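Note: the tightened mypy invocation is what drives most of this commit. --disallow-untyped-defs turns any unannotated function definition into an error, and --warn-unused-ignores reports '# type: ignore' comments that are no longer necessary (two are removed in the test files below). A minimal sketch of the first flag's effect:

    # Under mypy --disallow-untyped-defs:
    def scale(x):  # error: function is missing a type annotation
        return x * 2

    def scale_typed(x: int) -> int:  # fully annotated, accepted
        return x * 2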
3 changes: 2 additions & 1 deletion lambda_functions/analyzer/analyzer_aws_lib.py
@@ -7,7 +7,8 @@
from botocore.exceptions import ClientError

if __package__:
-from lambda_functions.analyzer.binary_info import BinaryInfo
+# BinaryInfo is imported here just for the type annotation - the cyclic import will resolve
+from lambda_functions.analyzer.binary_info import BinaryInfo # pylint: disable=cyclic-import
from lambda_functions.analyzer.common import LOGGER
else:
# mypy complains about duplicate definitions
4 changes: 2 additions & 2 deletions lambda_functions/analyzer/binary_info.py
@@ -59,7 +59,7 @@ def _download_from_s3(self) -> None:
self.bucket_name, self.object_key, self.download_path)
self.download_time_ms = (time.time() - start_time) * 1000

-def __enter__(self):
+def __enter__(self) -> Any: # mypy/typing doesn't support recursive type yet
"""Download the binary from S3 and run YARA analysis."""
self._download_from_s3()
self.computed_sha, self.computed_md5 = file_hash.compute_hashes(self.download_path)
@@ -71,7 +71,7 @@ def __enter__(self):

return self

-def __exit__(self, exception_type, exception_value, traceback):
+def __exit__(self, exception_type: Any, exception_value: Any, traceback: Any) -> None:
"""Shred the downloaded binary and delete it from disk."""
# Note: This runs even during exception handling (it is the "with" context).
subprocess.check_call(['shred', '--remove', self.download_path])
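Note: returning Any from __enter__ is a common workaround from this period: typing had no generic way to express "an instance of the enclosing class" for the with-target (typing.Self only arrived years later), short of a string forward reference. A minimal sketch of the same pattern with a hypothetical class:

    from typing import Any

    class ManagedDownload:
        def __enter__(self) -> Any:  # no self-type support yet, hence Any
            return self

        def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
            pass  # cleanup goes here; runs even if the with-body raised

A forward reference (-> 'ManagedDownload') would also type-check; Any matches the choice made above.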
2 changes: 1 addition & 1 deletion lambda_functions/analyzer/main.py
@@ -82,7 +82,7 @@ def _objects_to_analyze(event: Dict[str, Any]) -> Generator[Tuple[str, str], Non
yield from _s3_objects(event['Records'])


-def analyze_lambda_handler(event: Dict[str, Any], lambda_context) -> Dict[str, Dict[str, Any]]:
+def analyze_lambda_handler(event: Dict[str, Any], lambda_context: Any) -> Dict[str, Dict[str, Any]]:
"""Analyzer Lambda function entry point.
Args:
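Note: the Lambda context object is injected by the AWS runtime and has no importable class or stub to annotate against, so Any is the practical choice once --disallow-untyped-defs requires an annotation. A sketch of the pattern, with hypothetical names:

    from typing import Any, Dict

    def example_handler(event: Dict[str, Any], lambda_context: Any) -> None:
        # Attribute access on the context is only checked at runtime:
        remaining_ms = lambda_context.get_remaining_time_in_millis()
        print(event, remaining_ms)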
6 changes: 3 additions & 3 deletions lambda_functions/batcher/main.py
@@ -9,7 +9,7 @@
import json
import logging
import os
-from typing import Dict, List, Optional
+from typing import Any, Dict, List, Optional

import boto3

@@ -164,7 +164,7 @@ def __init__(self, bucket_name: str, prefix: Optional[str],
self.finished = False # Have we finished enumerating all of the S3 bucket?

@property
-def continuation_token(self):
+def continuation_token(self) -> str:
return self.kwargs.get('ContinuationToken')

def next_page(self) -> List[str]:
@@ -187,7 +187,7 @@ def next_page(self) -> List[str]:
return [obj['Key'] for obj in response['Contents']]


-def batch_lambda_handler(event: Dict[str, str], lambda_context) -> int:
+def batch_lambda_handler(event: Dict[str, str], lambda_context: Any) -> int:
"""Entry point for the batch Lambda function.
Args:
10 changes: 5 additions & 5 deletions lambda_functions/build.py
@@ -28,7 +28,7 @@
DOWNLOAD_ZIPFILE = 'lambda_downloader'


-def _build_analyzer(target_directory):
+def _build_analyzer(target_directory: str) -> None:
"""Build the YARA analyzer Lambda deployment package."""
print('Creating analyzer deploy package...')
pathlib.Path(os.path.join(ANALYZE_SOURCE, 'main.py')).touch()
@@ -57,23 +57,23 @@ def _build_analyzer(target_directory):
shutil.rmtree(temp_package_dir)


-def _build_batcher(target_directory):
+def _build_batcher(target_directory: str) -> None:
"""Build the batcher Lambda deployment package."""
print('Creating batcher deploy package...')
pathlib.Path(BATCH_SOURCE).touch() # Change last modified time to force new Lambda deploy
with zipfile.ZipFile(os.path.join(target_directory, BATCH_ZIPFILE + '.zip'), 'w') as pkg:
pkg.write(BATCH_SOURCE, os.path.basename(BATCH_SOURCE))


-def _build_dispatcher(target_directory):
+def _build_dispatcher(target_directory: str) -> None:
"""Build the dispatcher Lambda deployment package."""
print('Creating dispatcher deploy package...')
pathlib.Path(DISPATCH_SOURCE).touch()
with zipfile.ZipFile(os.path.join(target_directory, DISPATCH_ZIPFILE + '.zip'), 'w') as pkg:
pkg.write(DISPATCH_SOURCE, os.path.basename(DISPATCH_SOURCE))


-def _build_downloader(target_directory):
+def _build_downloader(target_directory: str) -> None:
"""Build the downloader Lambda deployment package."""
print('Creating downloader deploy package...')
pathlib.Path(DOWNLOAD_SOURCE).touch()
@@ -93,7 +93,7 @@ def _build_downloader(target_directory):
shutil.rmtree(temp_package_dir)


-def build(target_directory, downloader=False):
+def build(target_directory: str, downloader: bool = False) -> None:
"""Build Lambda deployment packages.
Args:
4 changes: 2 additions & 2 deletions lambda_functions/dispatcher/main.py
@@ -96,15 +96,15 @@ def _publish_metrics(batch_sizes: Dict[str, List[int]]) -> None:
CLOUDWATCH.put_metric_data(Namespace='BinaryAlert', MetricData=metric_data)


-def dispatch_lambda_handler(_, lambda_context):
+def dispatch_lambda_handler(_: Dict[str, Any], lambda_context: Any) -> None:
"""Dispatch Lambda function entry point.
Args:
_: Unused invocation event.
lambda_context: LambdaContext object with .get_remaining_time_in_millis().
"""
# Keep track of the batch sizes (one element for each invocation) for each target function.
-batch_sizes = {config.lambda_name: [] for config in DISPATCH_CONFIGS}
+batch_sizes: Dict[str, List[int]] = {config.lambda_name: [] for config in DISPATCH_CONFIGS}

# The maximum amount of time needed in the execution loop.
# This allows us to dispatch as long as possible while still staying under the time limit.
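Note: the batch_sizes change is a PEP 526 variable annotation. mypy cannot infer an element type for the empty lists produced by the comprehension, so without the annotation the later appends of batch sizes would not type-check. A minimal sketch with hypothetical names:

    from typing import Dict, List

    targets = ('analyzer', 'downloader')  # hypothetical target names
    batch_sizes: Dict[str, List[int]] = {name: [] for name in targets}
    batch_sizes['analyzer'].append(10)  # OK: the value type is declared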
2 changes: 1 addition & 1 deletion lambda_functions/downloader/main.py
@@ -149,7 +149,7 @@ def _publish_metrics(receive_counts: List[int]) -> None:
)


-def download_lambda_handler(event: Dict[str, Any], _) -> None:
+def download_lambda_handler(event: Dict[str, Any], _: Any) -> None:
"""Lambda function entry point - copy a binary from CarbonBlack into the BinaryAlert S3 bucket.
Args:
14 changes: 7 additions & 7 deletions manage.py
@@ -74,7 +74,7 @@ class BinaryAlertConfig(object):
VALID_CB_ENCRYPTED_TOKEN_FORMAT = r'\S{50,500}'
VALID_CB_URL_FORMAT = r'https?://\S+'

-def __init__(self):
+def __init__(self) -> None:
"""Parse the terraform.tfvars config file and make sure it contains every variable.
Raises:
@@ -98,7 +98,7 @@ def aws_region(self) -> str:
return self._config['aws_region']

@aws_region.setter
-def aws_region(self, value: str):
+def aws_region(self, value: str) -> None:
if not re.fullmatch(self.VALID_AWS_REGION_FORMAT, value, re.ASCII):
raise InvalidConfigError(
'aws_region "{}" does not match format {}'.format(
@@ -111,7 +111,7 @@ def name_prefix(self) -> str:
return self._config['name_prefix']

@name_prefix.setter
-def name_prefix(self, value: str):
+def name_prefix(self, value: str) -> None:
if not re.fullmatch(self.VALID_NAME_PREFIX_FORMAT, value, re.ASCII):
raise InvalidConfigError(
'name_prefix "{}" does not match format {}'.format(
@@ -124,7 +124,7 @@ def enable_carbon_black_downloader(self) -> int:
return self._config['enable_carbon_black_downloader']

@enable_carbon_black_downloader.setter
-def enable_carbon_black_downloader(self, value: int):
+def enable_carbon_black_downloader(self, value: int) -> None:
if value not in {0, 1}:
raise InvalidConfigError(
'enable_carbon_black_downloader "{}" must be either 0 or 1.'.format(value)
@@ -136,7 +136,7 @@ def carbon_black_url(self) -> str:
return self._config['carbon_black_url']

@carbon_black_url.setter
-def carbon_black_url(self, value: str):
+def carbon_black_url(self, value: str) -> None:
if not re.fullmatch(self.VALID_CB_URL_FORMAT, value, re.ASCII):
raise InvalidConfigError(
'carbon_black_url "{}" does not match format {}'.format(
@@ -149,7 +149,7 @@ def encrypted_carbon_black_api_token(self) -> str:
return self._config['encrypted_carbon_black_api_token']

@encrypted_carbon_black_api_token.setter
-def encrypted_carbon_black_api_token(self, value: str):
+def encrypted_carbon_black_api_token(self, value: str) -> None:
if not re.fullmatch(self.VALID_CB_ENCRYPTED_TOKEN_FORMAT, value, re.ASCII):
raise InvalidConfigError(
'encrypted_carbon_black_api_token "{}" does not match format {}'.format(
@@ -326,7 +326,7 @@ def save(self) -> None:
class Manager(object):
"""BinaryAlert management utility."""

-def __init__(self):
+def __init__(self) -> None:
"""Parse the terraform.tfvars config file."""
self._config = BinaryAlertConfig()

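Note: --disallow-untyped-defs applies to every def, so __init__ (conventionally annotated -> None) and property setters (which return nothing) need annotations too. A minimal sketch of the property pattern used above, with a hypothetical class:

    class Config:
        def __init__(self) -> None:
            self._region = 'us-east-1'  # hypothetical default

        @property
        def aws_region(self) -> str:
            return self._region

        @aws_region.setter
        def aws_region(self, value: str) -> None:  # setters return None
            self._region = value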
9 changes: 5 additions & 4 deletions rules/compile_rules.py
@@ -1,16 +1,17 @@
"""Compile all of the YARA rules into a single binary file."""
import os
+from typing import Generator

import yara

RULES_DIR = os.path.dirname(os.path.realpath(__file__)) # Directory containing this file.


-def _find_yara_files():
+def _find_yara_files() -> Generator[str, None, None]:
"""Find all .yar[a] files in the rules directory.
Yields:
-[string] YARA rule filepaths, relative to the rules root directory.
+YARA rule filepaths, relative to the rules root directory.
"""
for root, _, files in os.walk(RULES_DIR):
for filename in files:
@@ -19,11 +20,11 @@ def _find_yara_files():
yield os.path.relpath(os.path.join(root, filename), start=RULES_DIR)


-def compile_rules(target_path):
+def compile_rules(target_path: str) -> None:
"""Compile YARA rules into a single binary rules file.
Args:
-target_path: [String] Where to save the compiled rules file.
+target_path: Where to save the compiled rules file.
"""
# Each rule file must be keyed by an identifying "namespace"; in our case the relative path.
yara_filepaths = {relative_path: os.path.join(RULES_DIR, relative_path)
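Note: Generator takes three type parameters - Generator[YieldType, SendType, ReturnType] - and a plain generator function that only yields uses None for the last two. A minimal illustrative sketch:

    from typing import Generator

    def squares(limit: int) -> Generator[int, None, None]:
        # Yields ints, accepts nothing via send(), returns nothing.
        for i in range(limit):
            yield i * i

Iterator[str] would also be an accepted (and shorter) annotation for _find_yara_files, since it only yields.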
2 changes: 1 addition & 1 deletion tests/lambda_functions/analyzer/main_test.py
@@ -52,7 +52,7 @@ def metadata(self):
return GOOD_FILE_METADATA if self.key == GOOD_S3_OBJECT_KEY else EVIL_FILE_METADATA


-@mock.patch.dict(os.environ, {
+@mock.patch.dict(os.environ, values={
'LAMBDA_TASK_ROOT': '/var/task',
'SQS_QUEUE_URL': MOCK_SQS_URL,
'YARA_MATCHES_DYNAMO_TABLE_NAME': MOCK_DYNAMO_TABLE_NAME,
2 changes: 1 addition & 1 deletion tests/lambda_functions/analyzer/yara_analyzer_test.py
@@ -39,7 +39,7 @@
]


-@mock.patch.dict(os.environ, {'LAMBDA_TASK_ROOT': '/var/task'})
+@mock.patch.dict(os.environ, values={'LAMBDA_TASK_ROOT': '/var/task'})
class YaraAnalyzerTest(fake_filesystem_unittest.TestCase):
"""Uses the real YARA library to parse the test rules."""

4 changes: 2 additions & 2 deletions tests/lambda_functions/batcher/main_test.py
@@ -10,7 +10,7 @@
from tests import common


-@mock.patch.dict(os.environ, {
+@mock.patch.dict(os.environ, values={
'BATCH_LAMBDA_NAME': 'test_batch_lambda_name',
'BATCH_LAMBDA_QUALIFIER': 'Production',
'OBJECTS_PER_MESSAGE': '2',
@@ -263,7 +263,7 @@ def test_batcher_invoke_with_continuation(self):
])
])

-@mock.patch.dict(os.environ, {'OBJECT_PREFIX': 'important'}) # type: ignore
+@mock.patch.dict(os.environ, values={'OBJECT_PREFIX': 'important'})
def test_batcher_with_prefix(self):
"""Limit batch operation to object keys which start with the given prefix."""
def mock_list(**kwargs):
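Note: values is the actual name of mock.patch.dict's second parameter (patch.dict(in_dict, values=(), clear=False, **kwargs)), and passing the environment dict by keyword appears to be what lets mypy resolve these calls, which is why the '# type: ignore' above became removable under --warn-unused-ignores. A minimal sketch:

    import os
    from unittest import mock

    @mock.patch.dict(os.environ, values={'OBJECT_PREFIX': 'important'})
    def check_prefix() -> None:
        assert os.environ['OBJECT_PREFIX'] == 'important'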
2 changes: 1 addition & 1 deletion tests/lambda_functions/build_test.py
@@ -32,7 +32,7 @@ def _verify_filenames(self, archive_path: str, expected_filenames: Set[str],
subset: bool = False):
"""Verify the set of filenames in the zip archive matches the expected list."""
with zipfile.ZipFile(archive_path, 'r') as archive:
-filenames = set(zip_info.filename for zip_info in archive.filelist) # type: ignore
+filenames = set(zip_info.filename for zip_info in archive.filelist)
if subset:
self.assertTrue(expected_filenames.issubset(filenames))
else:
2 changes: 1 addition & 1 deletion tests/lambda_functions/dispatcher/main_test.py
@@ -28,7 +28,7 @@ def setUp(self):
'SQS_QUEUE_URLS': '{},{}'.format(url1, url2)
}

-with mock.patch.dict(os.environ, mock_environ):
+with mock.patch.dict(os.environ, values=mock_environ):
from lambda_functions.dispatcher import main
self.main = main

