Skip to content

Commit

Permalink
Merge pull request #118 from PUNCH-Cyber/errors
Browse files Browse the repository at this point in the history
Add Error class for improved error handling and context
  • Loading branch information
mlaferrera committed Aug 14, 2019
2 parents 229984e + 845e41d commit 4f8bda2
Show file tree
Hide file tree
Showing 11 changed files with 126 additions and 79 deletions.
1 change: 1 addition & 0 deletions stoq/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from .core import Stoq
from .data_classes import (
Error,
ArchiverResponse,
ExtractedPayload,
Payload,
Expand Down
89 changes: 52 additions & 37 deletions stoq/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,7 @@
from pythonjsonlogger import jsonlogger # type: ignore
from logging.handlers import RotatingFileHandler
from typing import (
Coroutine,
Awaitable,
Dict,
AsyncGenerator,
List,
Expand All @@ -326,6 +326,7 @@

from .exceptions import StoqException
from stoq.data_classes import (
Error,
Payload,
PayloadMeta,
PayloadResults,
Expand Down Expand Up @@ -516,7 +517,6 @@ async def scan_request(
"""
add_start_dispatch = add_start_dispatch or []
errors: DefaultDict[str, List[str]] = defaultdict(list)
scan_queue = [(payload, add_start_dispatch) for payload in request.payloads]
hashes_seen: Set[str] = set(
[helpers.get_sha256(payload.content) for payload in request.payloads]
Expand All @@ -525,19 +525,16 @@ async def scan_request(
for _recursion_level in range(self.max_recursion + 1):
next_scan_queue: List[Tuple[Payload, List[str]]] = []
for payload, add_dispatch in scan_queue:
extracted, p_errors = await self._single_scan(
payload, add_dispatch, request
)
extracted = await self._single_scan(payload, add_dispatch, request)
# TODO: Add option for no-dedup
for ex in extracted:
ex_hash = helpers.get_sha256(ex.content)
if ex_hash not in hashes_seen:
hashes_seen.add(ex_hash)
next_scan_queue.append((ex, ex.payload_meta.dispatch_to))
errors = helpers.merge_dicts(errors, p_errors)
scan_queue = next_scan_queue

response = StoqResponse(request=request, errors=errors)
response = StoqResponse(request=request)

decorator_tasks = []
for plugin_name, decorator in self._loaded_decorator_plugins.items():
Expand Down Expand Up @@ -625,10 +622,9 @@ async def _consume(

async def _single_scan(
self, payload: Payload, add_dispatch: List[str], request: Request
) -> Tuple[List[Payload], DefaultDict[str, List[str]]]:
) -> List[Payload]:
# TODO: Figure out Request usage
extracted: List[Payload] = []
errors: DefaultDict[str, List[str]] = defaultdict(list)
dispatches: Set[str] = set().union( # type: ignore
add_dispatch, self.always_dispatch
)
Expand All @@ -638,26 +634,22 @@ async def _single_scan(
dispatch_tasks.append(self._get_dispatches(dispatcher, payload, request))
dispatch_results = await asyncio.gather(*dispatch_tasks)

worker_tasks: List[Coroutine] = [
worker_tasks: List[Awaitable] = [
self._worker_start(w, payload, request) for w in dispatches
]
for dispatcher_name, dispatched_workers, dispatch_error in dispatch_results:
for dispatcher_name, dispatched_workers in dispatch_results:
for dispatched_worker in dispatched_workers:
worker_tasks.append(
self._worker_start(dispatched_worker, payload, request)
)
if dispatch_error:
errors[dispatcher_name].append(dispatch_error)
worker_results = await asyncio.gather(*worker_tasks) # type: ignore
payload_results = PayloadResults.from_payload(payload)

for worker_name, worker_response, worker_error in worker_results:
for worker_name, worker_response in worker_results:
if worker_response is None:
if worker_error:
errors[worker_name].append(worker_error)
continue
elif worker_response.errors:
errors[worker_name].extend(worker_response.errors)
request.errors.extend(worker_response.errors)

if worker_response.results is not None:
payload_results.workers[worker_name] = worker_response.results
Expand All @@ -679,50 +671,68 @@ async def _single_scan(
archive_tasks.append(self._archive_payload(archiver, payload, request))
archive_results = await asyncio.gather(*archive_tasks)

for archiver_name, archiver_response, archiver_error in archive_results:
for archiver_name, archiver_response in archive_results:
if archiver_response is None:
if archiver_error:
errors[archiver_name].append(archiver_error)
continue
elif archiver_response.errors:
errors[archiver_name].extend(archiver_response.errors)
request.errors.extend(archiver_response.errors)
if archiver_response.results is not None:
payload_results.archivers[archiver_name] = archiver_response.results
request.results.append(payload_results)
return (extracted, errors)
return extracted

async def _archive_payload(
self, archiver: ArchiverPlugin, payload: Payload, request: Request
) -> Tuple[str, Union[ArchiverResponse, None], Union[str, None]]:
) -> Tuple[str, Union[ArchiverResponse, None]]:
archiver_name = archiver.config.get('Core', 'Name')
archiver_response: Union[ArchiverResponse, None] = None
payload.plugins_run['archivers'].append(archiver_name)
try:
archiver_response = await archiver.archive(payload, request)
except Exception as e:
msg = 'archiver:failed to archive'
self.log.exception(msg)
return (archiver_name, None, helpers.format_exc(e, msg=msg))
return (archiver_name, archiver_response, None)
request.errors.append(
Error(
payload_id=payload.payload_id,
plugin_name=archiver_name,
error=helpers.format_exc(e, msg=msg),
)
)
return (archiver_name, archiver_response)

async def _worker_start(
self, dispatched_worker: str, payload: Payload, request: Request
) -> Tuple[str, Union[WorkerResponse, None], Union[None, str]]:
) -> Tuple[str, Union[WorkerResponse, None]]:
extracted: List[Payload] = []
worker_response: Union[None, WorkerResponse] = None
try:
plugin = self.load_plugin(dispatched_worker)
except Exception as e:
msg = 'worker:failed to load'
self.log.exception(msg)
return (dispatched_worker, None, helpers.format_exc(e, msg=msg))
request.errors.append(
Error(
payload_id=payload.payload_id,
plugin_name=dispatched_worker,
error=helpers.format_exc(e, msg=msg),
)
)
return (dispatched_worker, worker_response)
payload.plugins_run['workers'].append(dispatched_worker)
try:
worker_response = await plugin.scan(payload, request) # type: ignore
except Exception as e:
msg = 'worker:failed to scan'
self.log.exception(msg)
return (dispatched_worker, None, helpers.format_exc(e, msg=msg))

return (dispatched_worker, worker_response, None)
request.errors.append(
Error(
payload_id=payload.payload_id,
plugin_name=dispatched_worker,
error=helpers.format_exc(e, msg=msg),
)
)
return (dispatched_worker, worker_response)

def _init_logger(
self,
Expand Down Expand Up @@ -771,7 +781,7 @@ def _init_logger(

async def _get_dispatches(
self, dispatcher: DispatcherPlugin, payload: Payload, request: Request
) -> Tuple[str, Union[Set[str], None], Union[str, None]]:
) -> Tuple[str, Union[Set[str], None]]:

dispatcher_name = dispatcher.config.get('Core', 'Name')
plugin_names: Set[str] = set()
Expand All @@ -784,9 +794,14 @@ async def _get_dispatches(
except Exception as e:
msg = 'dispatcher:failed to dispatch'
self.log.exception(msg)
return (dispatcher_name, None, helpers.format_exc(e, msg=msg))

return (dispatcher_name, plugin_names, None)
request.errors.append(
Error(
plugin_name=dispatcher_name,
error=helpers.format_exc(e, msg=msg),
payload_id=payload.payload_id,
)
)
return (dispatcher_name, plugin_names)

async def _apply_decorators(
self, decorator: DecoratorPlugin, response: StoqResponse
Expand All @@ -798,14 +813,15 @@ async def _apply_decorators(
except Exception as e:
msg = 'decorator'
self.log.exception(msg)
response.errors[plugin_name].append(helpers.format_exc(e, msg='decorator'))
error = Error(plugin_name=plugin_name, error=helpers.format_exc(e, msg=msg))
response.errors.append(error)
return response
if decorator_response is None:
return response
if decorator_response.results is not None:
response.decorators[plugin_name] = decorator_response.results
if decorator_response.errors:
response.errors[plugin_name].extend(decorator_response.errors)
response.errors.extend(decorator_response.errors)
return response

async def _save_result(
Expand Down Expand Up @@ -833,7 +849,6 @@ async def reconstruct_all_subresponses(
)
new_response = StoqResponse(
request=new_request,
errors=stoq_response.errors,
time=stoq_response.time,
scan_id=stoq_response.scan_id,
)
Expand Down
37 changes: 26 additions & 11 deletions stoq/data_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,21 @@
import stoq.helpers as helpers


class Error:
    """Describes a single error produced while handling a request.

    Captures which plugin raised the error, the error message itself, and —
    when the error relates to a specific payload — that payload's id.

    :param plugin_name: Name of the plugin that generated the error
    :param error: Human-readable error message
    :param payload_id: Id of the payload being processed, if applicable
    """

    def __init__(
        self, plugin_name: str, error: str, payload_id: Optional[str] = None
    ) -> None:
        self.plugin_name, self.error, self.payload_id = plugin_name, error, payload_id

    def __str__(self) -> str:
        # Serialize through the project's JSON helper, consistent with the
        # other data classes in this module.
        return helpers.dumps(self)

    def __repr__(self):
        return repr(vars(self))


class PayloadMeta:
def __init__(
self,
Expand Down Expand Up @@ -191,10 +206,12 @@ def __init__(
payloads: Optional[List[Payload]] = None,
request_meta: Optional[RequestMeta] = None,
results: Optional[List[PayloadResults]] = None,
errors: Optional[List[Error]] = None,
):
self.payloads = payloads or []
self.request_meta = request_meta or RequestMeta()
self.results = results or []
self.errors = errors or []

def __str__(self) -> str:
return helpers.dumps(self)
Expand All @@ -207,7 +224,6 @@ class StoqResponse:
def __init__(
self,
request: Request,
errors: DefaultDict[str, List[str]],
time: Optional[str] = None,
decorators: Optional[Dict[str, Dict]] = None,
scan_id: Optional[str] = None,
Expand All @@ -218,14 +234,13 @@ def __init__(
:param results: ``PayloadResults`` object of scanned payload
:param request_meta: ``RequestMeta`` object pertaining to original scan request
:param errors: Errors that may have occurred during lifecycle of the payload
:param time: ISO Formatted timestamp of scan
:param decorators: Decorator plugin results
"""
self.results = request.results
self.request_meta = request.request_meta
self.errors = errors
self.errors = request.errors
self.time: str = datetime.now().isoformat() if time is None else time
self.decorators = {} if decorators is None else decorators
self.scan_id = str(uuid.uuid4()) if scan_id is None else scan_id
Expand Down Expand Up @@ -279,7 +294,7 @@ def __init__(
self,
results: Optional[Dict] = None,
extracted: Optional[List[ExtractedPayload]] = None,
errors: Optional[List[str]] = None,
errors: Optional[List[Error]] = None,
) -> None:
"""
Expand All @@ -296,7 +311,7 @@ def __init__(
"""
self.results = results
self.extracted = [] if extracted is None else extracted
self.errors = [] if errors is None else errors
self.errors = errors or []

def __str__(self) -> str:
return helpers.dumps(self)
Expand All @@ -307,7 +322,7 @@ def __repr__(self):

class ArchiverResponse:
def __init__(
self, results: Optional[Dict] = None, errors: Optional[List[str]] = None
self, results: Optional[Dict] = None, errors: Optional[List[Error]] = None
) -> None:
"""
Expand All @@ -321,7 +336,7 @@ def __init__(
"""
self.results = results
self.errors = [] if errors is None else errors
self.errors = errors or []

def __str__(self) -> str:
return helpers.dumps(self)
Expand All @@ -335,7 +350,7 @@ def __init__(
self,
plugin_names: Optional[List[str]] = None,
meta: Optional[Dict] = None,
errors: Optional[List[str]] = None,
errors: Optional[List[Error]] = None,
) -> None:
"""
Expand All @@ -352,7 +367,7 @@ def __init__(
"""
self.plugin_names = [] if plugin_names is None else plugin_names
self.meta = {} if meta is None else meta
self.errors = [] if errors is None else errors
self.errors = errors or []

def __str__(self) -> str:
return helpers.dumps(self)
Expand All @@ -363,7 +378,7 @@ def __repr__(self):

class DecoratorResponse:
def __init__(
self, results: Optional[Dict] = None, errors: Optional[List[str]] = None
self, results: Optional[Dict] = None, errors: Optional[List[Error]] = None
) -> None:
"""
Object containing response from decorator plugins
Expand All @@ -377,7 +392,7 @@ def __init__(
"""
self.results = results
self.errors = [] if errors is None else errors
self.errors = errors or []

def __str__(self) -> str:
return helpers.dumps(self)
Expand Down
Loading

0 comments on commit 4f8bda2

Please sign in to comment.