diff --git a/pyoaev/apis/__init__.py b/pyoaev/apis/__init__.py index afcd77c..ba25cc4 100644 --- a/pyoaev/apis/__init__.py +++ b/pyoaev/apis/__init__.py @@ -13,6 +13,7 @@ from .organization import * # noqa: F401,F403 from .payload import * # noqa: F401,F403 from .security_platform import * # noqa: F401,F403 +from .signature import * # noqa: F401,F403 from .tag import * # noqa: F401,F403 from .team import * # noqa: F401,F403 from .user import * # noqa: F401,F403 diff --git a/pyoaev/apis/signature.py b/pyoaev/apis/signature.py new file mode 100644 index 0000000..bed3e9b --- /dev/null +++ b/pyoaev/apis/signature.py @@ -0,0 +1,339 @@ +"""Signature callback API — transport layer for compiled signature payloads.""" + +import json +import logging +import time +from typing import Any + +from pydantic import ValidationError + +from pyoaev import exceptions as exc +from pyoaev.base import RESTManager, RESTObject +from pyoaev.exceptions import SignatureTransmissionError +from pyoaev.signatures.models import SignatureCallbackPayload + + +class Signature(RESTObject): + """REST object placeholder for signature callback responses.""" + + _id_attr = None + + +class SignatureApiManager(RESTManager): + """Manage signature callback transport to the OpenAEV backend. + + Handles payload validation, auto-chunking, and retry with exponential backoff. + """ + + _path = "/injects" + _obj_cls = Signature + + DEFAULT_MAX_PAYLOAD_SIZE = 1_048_576 # 1 MiB + MAX_RETRIES = 3 + RETRY_DELAYS = (1, 2, 4) + + _CHUNK_METADATA_RESERVE = len( + ',"chunk_index":99999,"total_chunks":99999,"phase":"execution_complete_extended"' + ) + + def __init__(self, openaev: "Any", parent: "Any" = None) -> None: + """Initialize the signature API manager. + + Args: + openaev: The OpenAEV client instance. + parent: Optional parent REST object for nested managers. + """ + super().__init__(openaev, parent) + self._max_payload_size = self.DEFAULT_MAX_PAYLOAD_SIZE + self._logger = logging.getLogger(__name__) + + @property + def max_payload_size(self) -> int: + """Maximum payload size in bytes before auto-chunking triggers.""" + return self._max_payload_size + + @max_payload_size.setter + def max_payload_size(self, value: int) -> None: + self._max_payload_size = value + + @property + def logger(self) -> logging.Logger: + """Logger instance used for transmission diagnostics.""" + return self._logger + + @logger.setter + def logger(self, value: logging.Logger) -> None: + self._logger = value + + def send_signatures( + self, + inject_id: str, + phase: str, + signatures: dict[str, Any], + ) -> None: + """Send compiled signatures to the inject callback endpoint. + + Auto-chunks payloads exceeding max_payload_size and retries on 5xx errors. + + Args: + inject_id: Inject UUID. + phase: Execution phase (e.g. 'execution_complete'). + signatures: Full signatures dict (canonical or flat, grouped on the fly). + + Raises: + SignatureTransmissionError: Validation failed, 4xx hit, or retries exhausted. + """ + self._logger.debug("send_signatures inject_id=%s phase=%s", inject_id, phase) + signatures = self._normalize_signature_payload(signatures) + payload = self._build_callback_payload(signatures, phase=phase) + + serialized = json.dumps(payload, separators=(",", ":")).encode() + + if len(serialized) <= self._max_payload_size: + self._send_with_retry(inject_id, payload) + else: + self._send_chunked(inject_id, payload["expectation_signature"], phase=phase) + + def _build_callback_payload( + self, + signatures: dict[str, Any], + *, + phase: str | None = None, + chunk_index: int | None = None, + total_chunks: int | None = None, + ) -> dict[str, Any]: + """Validate and wrap signatures in the strict callback envelope. + + Args: + signatures: The inner signatures body, already normalised. + phase: Execution phase string (e.g. 'execution_complete'). + chunk_index: 0-based index when chunking, None for single POSTs. + total_chunks: Chunk count when chunking, None for single POSTs. + + Returns: + The validated dict ready for wire transmission. + + Raises: + SignatureTransmissionError: Envelope failed Pydantic validation. + """ + try: + envelope = SignatureCallbackPayload.model_validate( + { + "expectation_signature": signatures, + "phase": phase, + "chunk_index": chunk_index, + "total_chunks": total_chunks, + } + ) + except ValidationError as ve: + raise SignatureTransmissionError( + error_message=f"Invalid signatures payload: {ve}", + ) from ve + return envelope.model_dump(mode="json", exclude_none=True) + + def _normalize_signature_payload( + self, signatures: dict[str, Any] + ) -> dict[str, Any]: + """Regroup signature_values by expectation_type within each target. + + Accepts flat or pre-grouped input and returns canonical grouped form. + + Args: + signatures: Raw signatures dict with any mix of flat and grouped entries. + + Returns: + New dict where every signature_values list is in canonical grouped form. + """ + targets = signatures.get("targets") + if not targets: + return signatures + + normalized_targets: list[dict[str, Any]] = [] + for target in targets: + sig_values = target.get("signature_values") + if not sig_values: + normalized_targets.append(target) + continue + + grouped: dict[str, list[dict[str, Any]]] = {} + order: list[str] = [] + + for entry in sig_values: + etype = entry.get("expectation_type") + if etype not in grouped: + grouped[etype] = [] + order.append(etype) + + if "values" in entry and isinstance(entry["values"], list): + grouped[etype].extend(entry["values"]) + else: + grouped[etype].append( + {k: v for k, v in entry.items() if k != "expectation_type"} + ) + + normalized_target = dict(target) + normalized_target["signature_values"] = [ + {"expectation_type": etype, "values": grouped[etype]} for etype in order + ] + normalized_targets.append(normalized_target) + + normalized = dict(signatures) + normalized["targets"] = normalized_targets + return normalized + + def _send_chunked( + self, inject_id: str, signatures: dict[str, Any], phase: str | None = None + ) -> None: + """Split targets across sequential POSTs, each tagged with chunk metadata. + + Args: + inject_id: Inject UUID for the callback path. + signatures: Normalised inner signatures body to partition. + phase: Execution phase forwarded to each chunk envelope. + + Raises: + SignatureTransmissionError: A single target alone exceeds max_payload_size. + """ + targets = signatures.get("targets", []) + if not targets: + payload = self._build_callback_payload(signatures, phase=phase) + size = len(json.dumps(payload, separators=(",", ":")).encode()) + if size > self._max_payload_size: + self._logger.warning( + "Payload of %d bytes exceeds max_payload_size %d but has no " + "'targets' key to chunk on; sending unchunked", + size, + self._max_payload_size, + ) + self._send_with_retry(inject_id, payload) + return + + budget = max(self._max_payload_size - self._CHUNK_METADATA_RESERVE, 0) + chunks: list[list[Any]] = [] + current_chunk: list[Any] = [] + + for target in targets: + candidate = current_chunk + [target] + size = len( + json.dumps( + {"expectation_signature": {"targets": candidate}}, + separators=(",", ":"), + ).encode() + ) + + if size <= budget: + current_chunk.append(target) + continue + + if not current_chunk: + raise SignatureTransmissionError( + error_message=( + f"Single target payload of {size} bytes exceeds " + f"max_payload_size {self._max_payload_size}; cannot chunk further" + ), + ) + + chunks.append(current_chunk) + current_chunk = [target] + solo_size = len( + json.dumps( + {"expectation_signature": {"targets": [target]}}, + separators=(",", ":"), + ).encode() + ) + if solo_size > budget: + raise SignatureTransmissionError( + error_message=( + f"Single target payload of {solo_size} bytes exceeds " + f"max_payload_size {self._max_payload_size}; cannot chunk further" + ), + ) + + if current_chunk: + chunks.append(current_chunk) + + total_chunks = len(chunks) + for idx, chunk_targets in enumerate(chunks): + chunk_payload = self._build_callback_payload( + {"targets": chunk_targets}, + phase=phase, + chunk_index=idx, + total_chunks=total_chunks, + ) + self._send_with_retry(inject_id, chunk_payload) + + @exc.on_http_error(exc.OpenAEVUpdateError) + def callback( + self, inject_id: str, data: dict[str, Any], **kwargs: Any + ) -> dict[str, Any]: + """Post signature payload to the inject callback endpoint. + + Args: + inject_id: Inject UUID. + data: Validated payload dict to send. + **kwargs: Additional arguments forwarded to http_post. + + Returns: + The parsed response from the backend. + """ + path = f"{self.path}/{inject_id}/callback" + result = self.openaev.http_post(path, post_data=data, **kwargs) + return result + + def _send_with_retry( + self, inject_id: str, payload: dict[str, Any] + ) -> dict[str, Any]: + """Retry callback() with exponential backoff on 5xx, immediate raise on 4xx. + + Args: + inject_id: Inject UUID for the callback path. + payload: Validated payload dict to send. + + Returns: + The successful response from callback(). + + Raises: + SignatureTransmissionError: 4xx error or all retries exhausted. + """ + from pyoaev.exceptions import OpenAEVError + + last_error: Exception | None = None + + for attempt in range(self.MAX_RETRIES + 1): + try: + return self.callback(inject_id, payload) + except OpenAEVError as ex: + status = ex.response_code + if status and 400 <= status < 500: + body_str = "" + if ex.response_body: + body_str = ex.response_body.decode(errors="replace") + self._logger.error( + "Client error %d sending signatures: %s", + status, + body_str or ex.error_message, + ) + raise SignatureTransmissionError( + error_message=f"Client error {status}: {ex.error_message}", + response_code=status, + response_body=ex.response_body, + ) from ex + + last_error = ex + if attempt < self.MAX_RETRIES: + delay = self.RETRY_DELAYS[attempt] + self._logger.warning( + "Retry %d/%d after %ds (HTTP %s): %s", + attempt + 1, + self.MAX_RETRIES, + delay, + status, + ex.error_message, + ) + time.sleep(delay) + + raise SignatureTransmissionError( + error_message=f"All {self.MAX_RETRIES} retries exhausted", + response_code=getattr(last_error, "response_code", None), + response_body=getattr(last_error, "response_body", None), + ) diff --git a/pyoaev/client.py b/pyoaev/client.py index 450e99d..e478d43 100644 --- a/pyoaev/client.py +++ b/pyoaev/client.py @@ -75,6 +75,7 @@ def __init__( self.payload = apis.PayloadManager(self) self.security_platform = apis.SecurityPlatformManager(self) self.inject_expectation_trace = apis.InjectExpectationTraceManager(self) + self.signature = apis.SignatureApiManager(self) self.tag = apis.TagManager(self) @staticmethod diff --git a/pyoaev/contracts/contract_config.py b/pyoaev/contracts/contract_config.py index 161008a..3cba107 100644 --- a/pyoaev/contracts/contract_config.py +++ b/pyoaev/contracts/contract_config.py @@ -57,6 +57,7 @@ class ContractOutputType(str, Enum): Sid: str = "sid" Vulnerability: str = "vulnerability" AccountWithPasswordNotRequired: str = "account_with_password_not_required" + ExpectationSignature: str = "expectation_signature" AsreproastableAccount: str = "asreproastable_account" KerberoastableAccount: str = "kerberoastable_account" @@ -152,12 +153,14 @@ class Contract: config: ContractConfig manual: bool variables: List[ContractVariable] = field( - default_factory=lambda: [ - VariableHelper.user_variable(), - VariableHelper.exercise_variable(), - VariableHelper.team_variable(), - ] - + VariableHelper.uri_variables() + default_factory=lambda: ( + [ + VariableHelper.user_variable(), + VariableHelper.exercise_variable(), + VariableHelper.team_variable(), + ] + + VariableHelper.uri_variables() + ) ) contract_attack_patterns_external_ids: List[str] = field(default_factory=list) contract_vulnerability_external_ids: List[str] = field(default_factory=list) @@ -212,7 +215,6 @@ def get_type(self) -> str: @dataclass class ContractText(ContractCardinalityElement): - defaultValue: str = "" @property @@ -253,7 +255,6 @@ def get_type(self) -> str: @dataclass class ContractTextArea(ContractCardinalityElement): - defaultValue: str = "" richText: bool = False @@ -264,7 +265,6 @@ def get_type(self) -> str: @dataclass class ContractCheckbox(ContractElement): - defaultValue: bool = False @property @@ -274,7 +274,6 @@ def get_type(self) -> str: @dataclass class ContractAttachment(ContractCardinalityElement): - @property def get_type(self) -> str: return ContractFieldType.Attachment.value @@ -292,7 +291,6 @@ def get_type(self) -> str: @dataclass class ContractSelect(ContractCardinalityElement): - choices: dict[str, str] = None @property @@ -320,7 +318,6 @@ def get_type(self) -> str: @dataclass class ContractPayload(ContractCardinalityElement): - @property def get_type(self) -> str: return ContractFieldType.Payload.value diff --git a/pyoaev/exceptions.py b/pyoaev/exceptions.py index 16f8a78..046405a 100644 --- a/pyoaev/exceptions.py +++ b/pyoaev/exceptions.py @@ -180,6 +180,12 @@ class OpenAEVCreateError(OpenAEVError): pass +class SignatureTransmissionError(OpenAEVError): + """Signatures didn't make it. Validation rejected them, 4xx slammed the door, or retries ran dry.""" + + pass + + class ConfigurationError(OpenAEVError): pass @@ -216,4 +222,5 @@ def wrapped_f(*args: Any, **kwargs: Any) -> Any: "OpenAEVListError", "OpenAEVGetError", "OpenAEVUpdateError", + "SignatureTransmissionError", ] diff --git a/pyoaev/signatures/__init__.py b/pyoaev/signatures/__init__.py index e69de29..75acc02 100644 --- a/pyoaev/signatures/__init__.py +++ b/pyoaev/signatures/__init__.py @@ -0,0 +1,32 @@ +from pyoaev.signatures.models import ( + CloudInjectorConfig, + ExpectationSignatureGroup, + ExternalInjectorConfig, + InjectorConfig, + NetworkInjectorConfig, + SignatureCallbackPayload, + SignaturePayload, + SignatureTarget, + SignatureValue, + TargetSignatures, + build_network_configs, +) +from pyoaev.signatures.signature_manager import SignatureManager +from pyoaev.signatures.types import MatchTypes, SignatureTypes + +__all__ = [ + "CloudInjectorConfig", + "ExpectationSignatureGroup", + "ExternalInjectorConfig", + "InjectorConfig", + "MatchTypes", + "NetworkInjectorConfig", + "SignatureCallbackPayload", + "SignatureManager", + "SignaturePayload", + "SignatureTarget", + "SignatureTypes", + "SignatureValue", + "TargetSignatures", + "build_network_configs", +] diff --git a/pyoaev/signatures/models.py b/pyoaev/signatures/models.py new file mode 100644 index 0000000..3fe285b --- /dev/null +++ b/pyoaev/signatures/models.py @@ -0,0 +1,227 @@ +"""Pydantic schemas pinning every shape SignatureManager touches.""" + +import ipaddress +from typing import Any + +from pydantic import BaseModel, ConfigDict + + +class SignatureValue(BaseModel): + """One signature observation: a type and the value it carries.""" + + model_config = ConfigDict(extra="allow") + + signature_type: str + signature_value: str + + +class ExpectationSignatureGroup(BaseModel): + """Values bound to a single expectation type (DETECTION, PREVENTION, ...).""" + + model_config = ConfigDict(extra="allow") + + expectation_type: str + values: list[SignatureValue] + + +class SignatureTarget(BaseModel): + """Target identity on the wire. Three fields, all mandatory, no exceptions.""" + + model_config = ConfigDict(extra="allow") + + agent: str + asset: str + asset_group: str + + +class TargetSignatures(BaseModel): + """A target plus everything observed about it, grouped by expectation.""" + + model_config = ConfigDict(extra="allow") + + signature_target: SignatureTarget + signature_values: list[ExpectationSignatureGroup] + + +class SignaturePayload(BaseModel): + """Inner ``signatures`` body: a list of targets, nothing else.""" + + model_config = ConfigDict(extra="allow") + + targets: list[TargetSignatures] + + +class SignatureCallbackPayload(BaseModel): + """Outer POST envelope. Pure ``{signatures}`` when unchunked, plus chunk fields when split.""" + + model_config = ConfigDict(populate_by_name=True, extra="forbid") + + expectation_signature: SignaturePayload + phase: str | None = None + chunk_index: int | None = None + total_chunks: int | None = None + + +class PreExecutionSignature(BaseModel): + """Pre-execution data dump. Field set varies by category: network, cloud, external.""" + + model_config = ConfigDict(extra="allow") + + # Timing always emitted at call time. + start_time: str | None = None + + # Network identity + source_ipv4: str | None = None + source_ipv6: str | None = None + target_ipv4: str | None = None + target_ipv6: str | None = None + target_hostname: str | None = None + + # Cloud identity + cloud_provider: str | None = None + cloud_account_id: str | None = None + cloud_region: str | None = None + target_service: str | None = None + + # External + query: str | None = None + + +class PostExecutionSignature(PreExecutionSignature): + """Post-execution view: pre-execution fields plus outcome, end_time, and any partial results.""" + + end_time: str | None = None + execution_status: str | None = None + partial_results: list[str] | None = None + + +class ToolErrorInfo(BaseModel): + """Crash report. Non-zero exit code and a timestamp if the tool left one behind.""" + + model_config = ConfigDict(extra="allow") + + exit_code: int = 0 + crash_timestamp: str | None = None + + +class ToolTimeoutInfo(BaseModel): + """Timeout report. Whatever partial loot was rescued before the kill signal.""" + + model_config = ConfigDict(extra="allow") + + partial_results: list[str] = [] + + +class ToolOutput(BaseModel): + """Whatever the tool spat out: status, error info, timeout info, or injector extras.""" + + model_config = ConfigDict(extra="allow") + + status: str | None = None + error_info: ToolErrorInfo | None = None + timeout_info: ToolTimeoutInfo | None = None + extra_signatures: dict[str, Any] | None = None + + +class NetworkInjectorConfig(BaseModel): + """A single network target. Exactly one of ``target_ipv4``, ``target_ipv6``, or ``target_hostname``.""" + + model_config = ConfigDict(extra="forbid") + + target_ipv4: str | None = None + target_ipv6: str | None = None + target_hostname: str | None = None + + +class CloudInjectorConfig(BaseModel): + """A single cloud target row. One config per region; fan out by passing a list.""" + + model_config = ConfigDict(extra="forbid") + + cloud_provider: str + cloud_account_id: str + cloud_region: str + target_service: str | None = None + + +class ExternalInjectorConfig(BaseModel): + """A single external scan target (e.g. Shodan): a query against an asset.""" + + model_config = ConfigDict(extra="forbid") + + query: str + target_ipv4: str | None = None + target_hostname: str | None = None + + +InjectorConfig = NetworkInjectorConfig | CloudInjectorConfig | ExternalInjectorConfig + + +# --------------------------------------------------------------------------- +# Builders. Cheap helpers to turn raw injector input into typed configs. +# --------------------------------------------------------------------------- + + +def _classify_network_target(value: str) -> NetworkInjectorConfig: + """Decide whether ``value`` is an IPv4, IPv6, or hostname and wrap it.""" + try: + addr = ipaddress.ip_address(value) + except ValueError: + return NetworkInjectorConfig(target_hostname=value) + if isinstance(addr, ipaddress.IPv4Address): + return NetworkInjectorConfig(target_ipv4=value) + return NetworkInjectorConfig(target_ipv6=value) + + +def build_network_configs( + targets: list[str | dict[str, Any] | NetworkInjectorConfig], +) -> list[NetworkInjectorConfig]: + """Forge a list of `NetworkInjectorConfig` from a heterogeneous target list. + + Each item is one distinct asset. Accepted shapes: + + - `NetworkInjectorConfig`: passed through unchanged. + - `dict`: validated against :class:`NetworkInjectorConfig`. + - `str`: auto-classified into IPv4 / IPv6 / hostname. + + Args: + targets: Raw target list straight out of the injector. + + Returns: + One `NetworkInjectorConfig` per input target, order preserved. + + Raises: + TypeError: An item is not one of the accepted shapes. + ValidationError: A dict item fails the one-identity invariant. + """ + configs: list[NetworkInjectorConfig] = [] + for target in targets: + if isinstance(target, NetworkInjectorConfig): + configs.append(target) + elif isinstance(target, dict): + configs.append(NetworkInjectorConfig(**target)) + elif isinstance(target, str): + configs.append(_classify_network_target(target)) + else: + raise TypeError(f"unsupported network target type: {type(target).__name__}") + return configs + + +__all__ = [ + "SignatureValue", + "ExpectationSignatureGroup", + "SignatureTarget", + "TargetSignatures", + "SignaturePayload", + "SignatureCallbackPayload", + "PreExecutionSignature", + "PostExecutionSignature", + "ToolErrorInfo", + "ToolTimeoutInfo", + "ToolOutput", + "NetworkInjectorConfig", + "CloudInjectorConfig", + "ExternalInjectorConfig", + "InjectorConfig", + "build_network_configs", +] diff --git a/pyoaev/signatures/signature_manager.py b/pyoaev/signatures/signature_manager.py new file mode 100644 index 0000000..db3d274 --- /dev/null +++ b/pyoaev/signatures/signature_manager.py @@ -0,0 +1,301 @@ +"""Signature lifecycle for OpenAEV injectors: compile pre, merge post, ship to backend.""" + +import logging +import os +import socket +import subprocess +from datetime import datetime, timezone +from typing import TYPE_CHECKING, Any + +from pydantic import ValidationError + +from pyoaev.exceptions import OpenAEVError +from pyoaev.signatures.models import ( + CloudInjectorConfig, + ExpectationSignatureGroup, + ExternalInjectorConfig, + InjectorConfig, + NetworkInjectorConfig, + PostExecutionSignature, + PreExecutionSignature, + SignaturePayload, + SignatureTarget, + SignatureValue, + TargetSignatures, + ToolOutput, +) + +if TYPE_CHECKING: + from pyoaev.client import OpenAEV + + +class SignatureManager: + """End-to-end signature pipeline: compile, merge, transmit. One class, three jobs.""" + + DEFAULT_MAX_PAYLOAD_SIZE = 1_048_576 # 1 MiB + + def __init__( + self, + client: "OpenAEV", + logger: logging.Logger | None = None, + max_payload_size: int = DEFAULT_MAX_PAYLOAD_SIZE, + ) -> None: + self.client = client + self.logger = logger or logging.getLogger(__name__) + self.max_payload_size = max_payload_size + self._cached_ipv4: str | None = None + self._cached_ipv6: str | None = None + + def _utcnow(self) -> datetime: + """Current UTC time. Carved out so tests can pin the clock.""" + return datetime.now(timezone.utc) + + def compile_pre_execution_signatures( + self, + config: InjectorConfig | list[InjectorConfig], + ) -> dict[str, Any] | list[dict[str, Any]]: + """Build pre-execution signature dicts from one or more typed injector configs. + + The category is carried by the config type itself + (:class:`NetworkInjectorConfig`, :class:`CloudInjectorConfig`, + :class:`ExternalInjectorConfig`), so no separate ``category`` flag is needed. + + Args: + config: A single injector config or a homogeneous list of them. + Multi-target injects must be expressed as a list. + + Returns: + One dict when a single config is given, otherwise a list of dicts in + input order. + + Raises: + ValueError: Empty list, or mixed config types in a single call. + TypeError: Unknown injector config type. + """ + configs = list(config) if isinstance(config, list) else [config] + if not configs: + raise ValueError( + "compile_pre_execution_signatures requires at least one config" + ) + + first_type = type(configs[0]) + if any(type(c) is not first_type for c in configs): + raise ValueError( + "compile_pre_execution_signatures does not mix injector config types; " + f"got {sorted({type(c).__name__ for c in configs})}" + ) + + start_time = self._utcnow().strftime("%Y-%m-%dT%H:%M:%SZ") + results = [self._compile_one(cfg, start_time) for cfg in configs] + return results[0] if len(results) == 1 else results + + def _compile_one(self, config: InjectorConfig, start_time: str) -> dict[str, Any]: + """Project a single injector config into a flat pre-execution signature dict. + + Common pipeline for every category: + 1. Seed the base dict with ``start_time`` and category-specific context + (network gets resolved source IPs; cloud/external add nothing). + 2. Layer the config's own fields on top. + 3. Run it through :class:`PreExecutionSignature` for validation + and emit JSON-ready output stripped of ``None``\\ s. + """ + base: dict[str, Any] = {"start_time": start_time} + base.update(self._source_context(config)) + base.update(config.model_dump(exclude_none=True)) + return PreExecutionSignature(**base).model_dump(mode="json", exclude_none=True) + + def _source_context(self, config: InjectorConfig) -> dict[str, Any]: + """Return the source identity bits injected for the config's category. + + Only network signatures need the running container's source IPs; + cloud and external rows have no source identity to carry. + """ + if isinstance(config, NetworkInjectorConfig): + return { + "source_ipv4": self.resolve_container_ip(), + "source_ipv6": self._cached_ipv6, + } + if isinstance(config, (CloudInjectorConfig, ExternalInjectorConfig)): + return {} + raise TypeError(f"unsupported injector config type: {type(config).__name__}") + + def compile_post_execution_signatures( + self, + pre_signatures: dict[str, Any] | list[dict[str, Any]], + tool_output: dict[str, Any], + ) -> dict[str, Any] | list[dict[str, Any]]: + """Merge pre-execution dicts with the tool's verdict into post-execution dicts. + + Args: + pre_signatures: One pre-execution dict or a list of them. + tool_output: Tool result with optional `error_info` / `timeout_info` / `status`. + + Returns: + Same shape as `pre_signatures`, now carrying `end_time` and `execution_status`. + """ + if isinstance(pre_signatures, list): + return [self._merge_post(sig, tool_output) for sig in pre_signatures] + return self._merge_post(pre_signatures, tool_output) + + def _merge_post( + self, pre_sig: dict[str, Any], tool_output: dict[str, Any] + ) -> dict[str, Any]: + try: + tool = ToolOutput.model_validate(tool_output or {}) + except ValidationError as exc: + raise OpenAEVError( + error_message=f"Invalid tool_output: {exc}", + ) from exc + + post = PostExecutionSignature.model_validate(pre_sig) + now = self._utcnow().strftime("%Y-%m-%dT%H:%M:%SZ") + + if tool.error_info and tool.error_info.exit_code != 0: + post.execution_status = "failed" + post.end_time = tool.error_info.crash_timestamp or now + elif tool.timeout_info: + post.execution_status = "timeout" + post.end_time = now + if tool.timeout_info.partial_results: + post.partial_results = tool.timeout_info.partial_results + elif tool.status == "partial": + post.execution_status = "partial" + post.end_time = now + else: + post.execution_status = "success" + post.end_time = now + + merged = post.model_dump(mode="json", exclude_none=True) + + if tool.extra_signatures: + merged.update(tool.extra_signatures) + + return merged + + def build_payload( + self, + post_signatures: dict[str, Any] | list[dict[str, Any]], + targets_meta: dict[str, str] | list[dict[str, str]], + expectation_type: str = "DETECTION", + ) -> dict[str, Any]: + """Build the nested wire payload from flat post-execution signatures. + + Bridges the gap between compile_post_execution_signatures output (flat dicts) + and send_signatures input (nested wire format). + + Args: + post_signatures: A single post-execution dict or a list (multi-target). + targets_meta: Target metadata dict(s) with keys like agent, asset, asset_group. + expectation_type: The expectation type label (e.g. 'DETECTION', 'PREVENTION'). + + Returns: + A payload dict ready for send_signatures. + """ + if isinstance(post_signatures, dict): + post_signatures = [post_signatures] + if isinstance(targets_meta, dict): + targets_meta = [targets_meta] * len(post_signatures) + + targets = [] + for sig, meta in zip(post_signatures, targets_meta): + values = [ + SignatureValue(signature_type=k, signature_value=str(v)) + for k, v in sig.items() + ] + targets.append( + TargetSignatures( + signature_target=SignatureTarget(**meta), + signature_values=[ + ExpectationSignatureGroup( + expectation_type=expectation_type, values=values + ) + ], + ) + ) + + return SignaturePayload(targets=targets).model_dump() + + def send_signatures( + self, + inject_id: str, + phase: str, + signatures: dict[str, Any], + ) -> None: + """Ship signatures to the callback endpoint via the Signature API manager. + + Delegates transport (retry, chunking, validation) to ``client.signature``. + + Args: + inject_id: Inject UUID. + phase: Execution phase. + signatures: Full signatures dict, canonical or flat, both grouped on the fly. + + Raises: + SignatureTransmissionError: Validation failed, 4xx hit, or retries exhausted. + """ + self.client.signature.max_payload_size = self.max_payload_size + self.client.signature.logger = self.logger + self.client.signature.send_signatures(inject_id, phase, signatures) + + def resolve_container_ip(self) -> str: + """Sniff the container's primary IPv4. Env var, hostname, then ``hostname -i``. + + Returns: + The IPv4 string, or ``'unknown'`` with a single warning when all strategies fail. + """ + if self._cached_ipv4: + return self._cached_ipv4 + + env_ip = os.environ.get("CONTAINER_IP") + if env_ip: + self._cached_ipv4 = env_ip + self._resolve_ipv6() + return env_ip + + try: + ip = socket.gethostbyname(socket.gethostname()) + if ip and ip != "127.0.0.1": + self._cached_ipv4 = ip + self._resolve_ipv6() + return ip + except (socket.gaierror, OSError): + pass + + try: + result = subprocess.run( + ["hostname", "-i"], + capture_output=True, + text=True, + timeout=5, + ) + if result.returncode == 0: + ip = result.stdout.strip().split()[0] + if ip and ip != "127.0.0.1": + self._cached_ipv4 = ip + self._resolve_ipv6() + return ip + except (OSError, RuntimeError, subprocess.TimeoutExpired): + pass + + self.logger.warning("Could not resolve container IP; returning 'unknown'") + self._cached_ipv4 = "unknown" + return "unknown" + + def _resolve_ipv6(self) -> None: + """Best-effort IPv6 sniff. Silent on failure, no exceptions escape.""" + try: + infos = socket.getaddrinfo( + socket.gethostname(), None, socket.AF_INET6, socket.SOCK_STREAM + ) + for info in infos: + addr = info[4][0] + if ( + isinstance(addr, str) + and addr + and not addr.startswith("::1") + and not addr.startswith("fe80") + ): + self._cached_ipv6 = addr + return + except (socket.gaierror, OSError): + pass diff --git a/pyoaev/signatures/types.py b/pyoaev/signatures/types.py index 4a479a1..88fd9b6 100644 --- a/pyoaev/signatures/types.py +++ b/pyoaev/signatures/types.py @@ -15,3 +15,8 @@ class SignatureTypes(str, Enum): SIG_TYPE_TARGET_HOSTNAME_ADDRESS = "target_hostname_address" SIG_TYPE_START_DATE = "start_date" SIG_TYPE_END_DATE = "end_date" + SIG_TYPE_CLOUD_PROVIDER = "cloud_provider" + SIG_TYPE_CLOUD_ACCOUNT_ID = "cloud_account_id" + SIG_TYPE_CLOUD_REGION = "cloud_region" + SIG_TYPE_TARGET_SERVICE = "target_service" + SIG_TYPE_QUERY = "query" diff --git a/pyproject.toml b/pyproject.toml index 01929e3..4c65201 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,7 +42,10 @@ dependencies = [ # OpenAEV, "requests-toolbelt (>=1.0.0,<1.1.0)", "dataclasses-json (>=0.6.4,<0.7.0)", - "thefuzz (>=0.22,<0.23)" + "thefuzz (>=0.22,<0.23)", + # Ugly fix, we need to fix the CI to use: [dev] + "pytest (>=9.0.0,<9.1.0)", + "pytest-bdd (>=8.1.0, <8.2.0)", ] [project.optional-dependencies] @@ -54,7 +57,9 @@ dev = [ "pre-commit (>=4.6.0,<4.7.0)", "types-python-dateutil (>=2.9.0,<2.10.0)", "wheel (>=0.47.0,<0.48.0)", - "coverage>=7.13.5" + "coverage>=7.13.5", + "pytest (>=9.0.0,<9.1.0)", + "pytest-bdd (>=8.1.0, <8.2.0)", ] doc = [ "autoapi (>=2.0.1,<2.1.0)", diff --git a/test/signatures/constraints/signature_manager_post_execution_constraints.feature b/test/signatures/constraints/signature_manager_post_execution_constraints.feature new file mode 100644 index 0000000..6d163bf --- /dev/null +++ b/test/signatures/constraints/signature_manager_post_execution_constraints.feature @@ -0,0 +1,33 @@ +Feature: SignatureManager post-execution constraints + As an injector using the OpenAEV client + I want post-execution compilation to handle failure and timeout edge cases + So that incomplete executions still produce valid signature records + + Background: + Given a SignatureManager initialised with constructor SignatureManager(client, logger) + And a pre_signatures dict containing: + | key | value | + | source_ipv4 | 172.17.0.2 | + | target_ipv4 | 10.0.0.1 | + | target_hostname | host-a.internal | + | start_time | 2024-06-26T06:00:00Z | + + Scenario: Tool crash sets execution_status to failed and uses crash timestamp as end_time + Given a tool_output containing error_info with exit_code=1 and crash_timestamp="2024-06-26T06:05:00Z" + When I call compile_post_execution_signatures with the pre_signatures dict and tool_output + Then execution_status equals "failed" + And end_time equals "2024-06-26T06:05:00Z" + And all pre-execution fields from pre_signatures are present and unchanged in the returned dict + + Scenario: Timeout sets execution_status to timeout and includes available partial results + Given a tool_output containing timeout_info with partial_results=["result-A", "result-B"] + When I call compile_post_execution_signatures with the pre_signatures dict and tool_output + Then execution_status equals "timeout" + And the returned dict contains the partial results ["result-A", "result-B"] from timeout_info + And all pre-execution fields from pre_signatures are present and unchanged in the returned dict + + Scenario: Timeout with no partial results still sets execution_status to timeout + Given a tool_output containing timeout_info with no partial results available + When I call compile_post_execution_signatures with the pre_signatures dict and tool_output + Then execution_status equals "timeout" + And all pre-execution fields from pre_signatures are present and unchanged in the returned dict diff --git a/test/signatures/constraints/signature_manager_pre_execution_constraints.feature b/test/signatures/constraints/signature_manager_pre_execution_constraints.feature new file mode 100644 index 0000000..e37a1e5 --- /dev/null +++ b/test/signatures/constraints/signature_manager_pre_execution_constraints.feature @@ -0,0 +1,15 @@ +Feature: SignatureManager pre-execution constraints + As an injector using the OpenAEV client + I want pre-execution compilation to handle timing edge cases correctly + So that signatures always reflect the actual moment of execution + + Background: + Given a SignatureManager initialised with constructor SignatureManager(client, logger) + + Scenario: start_time is generated at method-call time not at class instantiation time + Given a SignatureManager that was instantiated at timestamp T0 + And 5 seconds elapse after instantiation + And a NetworkInjectorConfig with target_ipv4="192.168.1.10" + When I call compile_pre_execution_signatures with the config at timestamp T1 + Then the start_time in the returned dict equals T1 within 1 second tolerance + And start_time does not equal T0 diff --git a/test/signatures/constraints/signature_manager_transmission_constraints.feature b/test/signatures/constraints/signature_manager_transmission_constraints.feature new file mode 100644 index 0000000..0fcf4f1 --- /dev/null +++ b/test/signatures/constraints/signature_manager_transmission_constraints.feature @@ -0,0 +1,45 @@ +Feature: SignatureManager transmission constraints + As an injector using the OpenAEV client + I want signature transmission to handle large payloads, transient errors, and client errors correctly + So that signatures are reliably delivered even under adverse conditions + + Background: + Given a SignatureManager initialised with constructor SignatureManager(client, logger) + + Scenario: Payload exceeding MAX_PAYLOAD_SIZE is auto-chunked with chunk metadata + Given a compiled payload whose serialised size exceeds MAX_PAYLOAD_SIZE by at least a factor of 2 + And the backend responds with HTTP 200 + When I call send_signatures for inject_id "inject-abc-001" with phase "execution_complete" + Then the payload is sent as multiple sequential POST requests to /injects/inject-abc-001/callback + And each POST request body contains chunk_index as a 0-based integer + And each POST request body contains total_chunks as a positive integer matching the total number of chunks sent + And each POST request body contains only "signatures", "chunk_index" and "total_chunks" at the top level + And the union of targets across all POST requests equals the original target set + And no individual POST request body exceeds MAX_PAYLOAD_SIZE bytes + + Scenario: HTTP 5xx response triggers exponential backoff retry for up to 3 additional attempts + Given a compiled post-execution payload for inject_id "inject-abc-001" + And the backend responds with HTTP 503 on every attempt + When I call send_signatures for inject_id "inject-abc-001" with phase "execution_complete" + Then send_signatures sends a total of 4 POST requests to /injects/inject-abc-001/callback + And a WARNING log message containing the retry attempt number is emitted before each of the 3 retry attempts + And the wait before attempt 2 is 1 second + And the wait before attempt 3 is 2 seconds + And the wait before attempt 4 is 4 seconds + And a SignatureTransmissionError is raised after all retries are exhausted + + Scenario: HTTP 4xx response raises an exception immediately with no retries and no sleep + Given a compiled post-execution payload for inject_id "inject-abc-001" + And the backend responds with HTTP 400 and body '{"error": "bad request"}' + When I call send_signatures for inject_id "inject-abc-001" with phase "execution_complete" + Then only 1 POST request is sent to /injects/inject-abc-001/callback + And an ERROR log message containing status code 400 and the response body is emitted + And an exception is raised immediately + And no sleep or wait occurs before the exception is raised + + Scenario: resolve_container_ip returns unknown and emits exactly one warning when all strategies fail + Given all IP resolution strategies are mocked to fail + When I call resolve_container_ip + Then the returned value is the string "unknown" + And exactly 1 WARNING log message is emitted + And no exception propagates from resolve_container_ip diff --git a/test/signatures/features/signature_manager_backward_compat.feature b/test/signatures/features/signature_manager_backward_compat.feature new file mode 100644 index 0000000..a683ac4 --- /dev/null +++ b/test/signatures/features/signature_manager_backward_compat.feature @@ -0,0 +1,17 @@ +Feature: SignatureManager backward compatibility with existing pyoaev consumers + As a maintainer of the OpenAEV client library + I want SignatureManager to integrate without breaking any existing code + So that all current injectors continue to work unchanged after the merge + + Scenario: Injectors that do not call SignatureManager experience no behavioural change + Given an injector that does not call any SignatureManager method + When that injector executes its normal workflow + Then its behaviour is identical to its behaviour before SignatureManager was introduced + And no import errors, attribute errors, or unexpected exceptions occur + + Scenario: Existing public import paths in pyoaev remain unchanged + Given the pyoaev package with SignatureManager merged + When existing code imports InjectManager, SignatureType, SignatureTypes, SignatureMatch, or Expectation using their current import paths + Then all imports resolve without error + And all constructor signatures remain unchanged + And all public method signatures remain unchanged diff --git a/test/signatures/features/signature_manager_post_execution.feature b/test/signatures/features/signature_manager_post_execution.feature new file mode 100644 index 0000000..856b1dd --- /dev/null +++ b/test/signatures/features/signature_manager_post_execution.feature @@ -0,0 +1,31 @@ +Feature: SignatureManager post-execution signature compilation + As an injector using the OpenAEV client + I want to merge execution results into pre-execution signatures + So that each inject has a complete signature record including outcome and timing + + Background: + Given a SignatureManager initialised with constructor SignatureManager(client, logger) + And a pre_signatures dict containing: + | key | value | + | source_ipv4 | 172.17.0.2 | + | target_ipv4 | 10.0.0.1 | + | target_hostname | host-a.internal | + | start_time | 2024-06-26T06:00:00Z | + + Scenario: Successful execution merges end_time and execution_status into pre-execution fields + Given a tool_output indicating successful completion with no errors and no timeout + When I call compile_post_execution_signatures with the pre_signatures dict and tool_output + Then the returned dict contains every key-value pair from pre_signatures unchanged + And the returned dict contains end_time as a UTC ISO 8601 string + And end_time is chronologically greater than or equal to start_time "2024-06-26T06:00:00Z" + And the returned dict contains execution_status equal to "success" + + + Scenario: Multi-target pre-signatures merge into a list of post-signatures + Given the pre_signatures is replaced by a list of 3 dicts each with a distinct target_ipv4 + And a tool_output indicating successful completion with no errors and no timeout + When I call compile_post_execution_signatures with the pre_signatures dict and tool_output + Then the returned value is a list of exactly 3 dicts + And every dict in the returned list contains execution_status equal to "success" + And every dict in the returned list contains end_time as a UTC ISO 8601 string + And every dict in the returned list preserves its original target_ipv4 and source_ipv4 fields diff --git a/test/signatures/features/signature_manager_pre_execution.feature b/test/signatures/features/signature_manager_pre_execution.feature new file mode 100644 index 0000000..1ba3ed5 --- /dev/null +++ b/test/signatures/features/signature_manager_pre_execution.feature @@ -0,0 +1,91 @@ +Feature: SignatureManager pre-execution signature compilation + As an injector using the OpenAEV client + I want to compile category-specific pre-execution signatures + So that each inject has a correct, typed signature payload before execution begins + + Background: + Given a SignatureManager initialised with constructor SignatureManager(client, logger) + + Scenario: Network category returns required IP and timing fields and no cloud or query fields + Given a NetworkInjectorConfig with target_ipv4="192.168.1.10" + And the running container has a resolvable IPv4 address + When I call compile_pre_execution_signatures with the config + Then the returned dict contains source_ipv4 as a non-empty valid IPv4 address string + And the returned dict contains start_time as a UTC ISO 8601 string + And the returned dict contains target_ipv4 equal to "192.168.1.10" + But the returned dict does not contain cloud_provider + And the returned dict does not contain cloud_account_id + And the returned dict does not contain cloud_region + And the returned dict does not contain target_service + And the returned dict does not contain query + + Scenario: Network hostname target returns hostname and no target_ipv4 + Given a NetworkInjectorConfig with target_hostname="target.example.com" + And the running container has a resolvable IPv4 address + When I call compile_pre_execution_signatures with the config + Then the returned dict contains target_hostname equal to "target.example.com" + And the returned dict contains source_ipv4 as a non-empty valid IPv4 address string + But the returned dict does not contain target_ipv4 + + Scenario: Cloud category returns required cloud identity fields and no IP fields + Given a CloudInjectorConfig with cloud_provider="aws", cloud_account_id="123456789012", cloud_region="eu-west-1", and target_service="ec2" + When I call compile_pre_execution_signatures with the config + Then the returned dict contains cloud_provider equal to "aws" + And the returned dict contains cloud_account_id equal to "123456789012" + And the returned dict contains cloud_region equal to "eu-west-1" + And the returned dict contains target_service equal to "ec2" + And the returned dict contains start_time as a UTC ISO 8601 string + But the returned dict does not contain source_ipv4 + And the returned dict does not contain source_ipv6 + And the returned dict does not contain target_ipv4 + And the returned dict does not contain target_ipv6 + + Scenario: External category returns scan target fields and no source IP + Given an ExternalInjectorConfig with target_ipv4="203.0.113.5" and query="port:22 os:linux" + When I call compile_pre_execution_signatures with the config + Then the returned dict contains target_ipv4 equal to "203.0.113.5" + And the returned dict contains query equal to "port:22 os:linux" + And the returned dict contains start_time as a UTC ISO 8601 string + But the returned dict does not contain source_ipv4 + + Scenario Outline: Network multi-target returns one dict per target with a shared source IP + Given a list of 3 NetworkInjectorConfig with target_ipv4 "10.0.0.1", "10.0.0.2", "10.0.0.3" + And the running container has a resolvable IPv4 address "172.17.0.2" + When I call compile_pre_execution_signatures with the config list + Then the return value is a list of exactly 3 dicts + And the dict at position contains target_ipv4 equal to "" + And the dict at position contains source_ipv4 equal to "172.17.0.2" + + Examples: + | index | target_ip | + | 0 | 10.0.0.1 | + | 1 | 10.0.0.2 | + | 2 | 10.0.0.3 | + + Scenario: All network multi-target dicts share the same source_ipv4 + Given a list of 3 NetworkInjectorConfig built from default IPv4 targets + And the running container has a resolvable IPv4 address + When I call compile_pre_execution_signatures with the config list + Then the return value is a list of 3 dicts + And all 3 dicts contain the same source_ipv4 value + + Scenario Outline: Cloud multi-region returns one dict per region with a shared account ID + Given a list of 3 CloudInjectorConfig with cloud_account_id="123456789012" and regions "us-east-1", "eu-west-1", "ap-southeast-1" + When I call compile_pre_execution_signatures with the config list + Then the return value is a list of exactly 3 dicts + And the dict at position contains cloud_region equal to "" + And the dict at position contains cloud_account_id equal to "123456789012" + + Examples: + | index | region | + | 0 | us-east-1 | + | 1 | eu-west-1 | + | 2 | ap-southeast-1 | + + Scenario: Builder classifies a mixed list of targets into typed configs + Given a raw mixed target list "10.0.0.1", "2001:db8::1", "target.example.com" + When I build network configs from the raw list + Then the builder returns 3 NetworkInjectorConfig + And the config at position 0 has target_ipv4 equal to "10.0.0.1" + And the config at position 1 has target_ipv6 equal to "2001:db8::1" + And the config at position 2 has target_hostname equal to "target.example.com" diff --git a/test/signatures/features/signature_manager_transmission.feature b/test/signatures/features/signature_manager_transmission.feature new file mode 100644 index 0000000..94e34bd --- /dev/null +++ b/test/signatures/features/signature_manager_transmission.feature @@ -0,0 +1,50 @@ +Feature: SignatureManager signature transmission and container IP resolution + As an injector using the OpenAEV client + I want to send compiled signatures to the backend and resolve my container's IP + So that inject results are recorded and IP-based signatures are accurate + + Background: + Given a SignatureManager initialised with constructor SignatureManager(client, logger) + + Scenario Outline: HTTP 2xx response is treated as successful transmission + Given a compiled post-execution payload for inject_id "inject-abc-001" + And the backend responds with HTTP + When I call send_signatures for inject_id "inject-abc-001" with phase "execution_complete" + Then send_signatures completes without raising an exception + + Examples: + | status_code | + | 200 | + | 202 | + + Scenario: send_signatures posts to the inject callback with the agreed nested schema + Given a compiled payload with 1 target, expectation_type "DETECTION", signature_type "public_ip", signature_value "203.0.113.5" + And the backend responds with HTTP 200 + When I call send_signatures for inject_id "inject-abc-001" with phase "execution_complete" + Then a POST request is sent to /injects/inject-abc-001/callback + And the POST request body contains signatures.targets as a list + And signatures.targets[0].signature_values[0].expectation_type equals "DETECTION" + And signatures.targets[0].signature_values[0].values[0].signature_type equals "public_ip" + And signatures.targets[0].signature_values[0].values[0].signature_value equals "203.0.113.5" + And signatures.targets[0] contains a signature_target key + + + Scenario Outline: resolve_container_ip returns a valid IPv4 in each supported execution environment + Given a SignatureManager running in a "" environment + When I call resolve_container_ip + Then the returned value is a non-empty valid IPv4 address string + + Examples: + | environment | + | Docker | + | Kubernetes | + | bare-metal | + + + Scenario: Payload schema groups signature values by expectation_type within each target + Given a compiled payload for 1 target with signatures of expectation_type "DETECTION" and expectation_type "PREVENTION" + And the backend responds with HTTP 200 + When I call send_signatures for inject_id "inject-abc-001" with phase "execution_complete" + Then the POST request body nests signature values under separate expectation_type entries within signatures.targets[0].signature_values + And the entry with expectation_type "DETECTION" contains only DETECTION signature values + And the entry with expectation_type "PREVENTION" contains only PREVENTION signature values diff --git a/test/signatures/test_signature_manager_backward_compat.py b/test/signatures/test_signature_manager_backward_compat.py new file mode 100644 index 0000000..27109e0 --- /dev/null +++ b/test/signatures/test_signature_manager_backward_compat.py @@ -0,0 +1,238 @@ +import inspect + +import pytest +from pytest_bdd import given, scenario, then, when + + +@scenario( + "features/signature_manager_backward_compat.feature", + "Injectors that do not call SignatureManager experience no behavioural change", +) +def test_injector_no_behavioural_change(): + pass + + +@scenario( + "features/signature_manager_backward_compat.feature", + "Existing public import paths in pyoaev remain unchanged", +) +def test_existing_public_import_paths_remain_unchanged(): + pass + + +# -------------------------------------------------- +# FIXTURE CONTEXT +# -------------------------------------------------- + + +@pytest.fixture +def context(): + return {} + + +# -------------------------------------------------- +# GIVEN +# -------------------------------------------------- + + +@given("an injector that does not call any SignatureManager method") +def injector_without_signature_manager(context, monkeypatch): + from pyoaev import OpenAEV + + monkeypatch.setattr( + OpenAEV, + "http_post", + lambda self, path, post_data=None, **kwargs: { + "path": path, + "post_data": post_data, + }, + ) + + client = OpenAEV("url", "token") + context["client"] = client + + +@given("the pyoaev package with SignatureManager merged") +def pyoaev_package_available(context): + context["pyoaev_package_available"] = True + + +# -------------------------------------------------- +# WHEN +# -------------------------------------------------- + + +@when("that injector executes its normal workflow") +def execute_injector_workflow(context): + result = None + caught_exception = None + try: + result = context["client"].inject.execution_callback( + "inject-id", + {"result": "ok"}, + ) + except Exception as exc: # pragma: no cover + caught_exception = exc + + context["workflow_result"] = result + context["workflow_exception"] = caught_exception + + +@when( + "existing code imports InjectManager, SignatureType, SignatureTypes, SignatureMatch, or Expectation using their current import paths" +) +def import_existing_public_paths(context): + imported = {} + import_error = None + try: + from pyoaev import OpenAEV + from pyoaev.apis.inject import InjectManager + from pyoaev.apis.inject_expectation.model.expectation import ( + DetectionExpectation, + Expectation, + PreventionExpectation, + ) + from pyoaev.signatures.signature_match import SignatureMatch + from pyoaev.signatures.signature_type import SignatureType + from pyoaev.signatures.types import MatchTypes, SignatureTypes + + imported = { + "OpenAEV": OpenAEV, + "InjectManager": InjectManager, + "SignatureTypes": SignatureTypes, + "MatchTypes": MatchTypes, + "SignatureType": SignatureType, + "SignatureMatch": SignatureMatch, + "Expectation": Expectation, + "DetectionExpectation": DetectionExpectation, + "PreventionExpectation": PreventionExpectation, + } + except ImportError as exc: # pragma: no cover + import_error = exc + + context["imported"] = imported + context["import_error"] = import_error + + +# -------------------------------------------------- +# THEN +# -------------------------------------------------- + + +@then( + "its behaviour is identical to its behaviour before SignatureManager was introduced" +) +def assert_behaviour_unchanged(context): + assert context["workflow_result"] == { + "path": "/injects/execution/callback/inject-id", + "post_data": {"result": "ok"}, + } + + +@then("no import errors, attribute errors, or unexpected exceptions occur") +def assert_no_exceptions(context): + assert context["workflow_exception"] is None + + +@then("all imports resolve without error") +def assert_imports_resolve(context): + assert context["import_error"] is None + + expected_symbols = { + "OpenAEV", + "InjectManager", + "SignatureTypes", + "MatchTypes", + "SignatureType", + "SignatureMatch", + "Expectation", + "DetectionExpectation", + "PreventionExpectation", + } + assert expected_symbols.issubset(set(context["imported"].keys())) + + +@then("all constructor signatures remain unchanged") +def assert_constructor_signatures(context): + imported = context["imported"] + + openaev_params = list(inspect.signature(imported["OpenAEV"]).parameters) + assert openaev_params[:9] == [ + "url", + "token", + "timeout", + "per_page", + "pagination", + "order_by", + "ssl_verify", + "tenant_id", + "kwargs", + ] + + assert list(inspect.signature(imported["InjectManager"]).parameters) == [ + "openaev", + "parent", + ] + assert list(inspect.signature(imported["SignatureType"]).parameters) == [ + "label", + "match_type", + "match_score", + ] + assert list(inspect.signature(imported["SignatureMatch"]).parameters) == [ + "match_type", + "match_score", + ] + + expectation_params = list(inspect.signature(imported["Expectation"]).parameters) + for required in ( + "inject_expectation_id", + "inject_expectation_signatures", + "success_label", + "failure_label", + ): + assert required in expectation_params + + +@then("all public method signatures remain unchanged") +def assert_public_method_signatures(context): + imported = context["imported"] + + def params(fn): + return list(inspect.signature(fn).parameters) + + assert params(imported["InjectManager"].execution_callback) == [ + "self", + "inject_id", + "data", + "kwargs", + ] + assert params(imported["InjectManager"].execution_reception) == [ + "self", + "inject_id", + "data", + "kwargs", + ] + assert params(imported["SignatureType"].make_struct_for_matching) == [ + "self", + "data", + ] + assert params(imported["Expectation"].update) == [ + "self", + "success", + "sender_id", + "metadata", + ] + assert params(imported["Expectation"].match_alert) == [ + "self", + "relevant_signature_types", + "alert_data", + ] + assert params(imported["Expectation"].match_fuzzy) == [ + "tested", + "reference", + "threshold", + ] + assert params(imported["Expectation"].match_simple) == [ + "tested", + "reference", + ] diff --git a/test/signatures/test_signature_manager_post_execution.py b/test/signatures/test_signature_manager_post_execution.py new file mode 100644 index 0000000..cabaf14 --- /dev/null +++ b/test/signatures/test_signature_manager_post_execution.py @@ -0,0 +1,234 @@ +from datetime import datetime +from unittest.mock import MagicMock + +import pytest +from pytest_bdd import given, parsers, scenario, then, when + +from pyoaev.signatures.signature_manager import SignatureManager + + +@scenario( + "features/signature_manager_post_execution.feature", + "Successful execution merges end_time and execution_status into pre-execution fields", +) +def test_successful_execution_merges_post_execution_fields(): + pass + + +@scenario( + "constraints/signature_manager_post_execution_constraints.feature", + "Tool crash sets execution_status to failed and uses crash timestamp as end_time", +) +def test_tool_crash_sets_failed_status_and_crash_timestamp_end_time(): + pass + + +@scenario( + "constraints/signature_manager_post_execution_constraints.feature", + "Timeout sets execution_status to timeout and includes available partial results", +) +def test_timeout_sets_timeout_status_and_includes_partial_results(): + pass + + +@scenario( + "constraints/signature_manager_post_execution_constraints.feature", + "Timeout with no partial results still sets execution_status to timeout", +) +def test_timeout_without_partial_results_still_sets_timeout_status(): + pass + + +@scenario( + "features/signature_manager_post_execution.feature", + "Multi-target pre-signatures merge into a list of post-signatures", +) +def test_multi_target_pre_signatures_merge_into_a_list_of_post_signatures(): + pass + + +@pytest.fixture +def context(): + return {} + + +def _parse_iso8601_utc(value: str) -> datetime: + return datetime.fromisoformat(value.replace("Z", "+00:00")) + + +@given( + "a SignatureManager initialised with constructor SignatureManager(client, logger)" +) +def signature_manager(context): + context["signature_manager"] = SignatureManager(MagicMock(), MagicMock()) + + +@given( + "a pre_signatures dict containing:", + target_fixture="pre_signatures", +) +def pre_signatures(): + return { + "source_ipv4": "172.17.0.2", + "target_ipv4": "10.0.0.1", + "target_hostname": "host-a.internal", + "start_time": "2024-06-26T06:00:00Z", + } + + +@given( + "a tool_output indicating successful completion with no errors and no timeout", + target_fixture="tool_output", +) +def successful_tool_output(): + return {"status": "success"} + + +@given( + 'a tool_output containing error_info with exit_code=1 and crash_timestamp="2024-06-26T06:05:00Z"', + target_fixture="tool_output", +) +def crashed_tool_output(): + return { + "error_info": { + "exit_code": 1, + "crash_timestamp": "2024-06-26T06:05:00Z", + } + } + + +@given( + 'a tool_output containing timeout_info with partial_results=["result-A", "result-B"]', + target_fixture="tool_output", +) +def timeout_tool_output_with_partial_results(): + return {"timeout_info": {"partial_results": ["result-A", "result-B"]}} + + +@given( + "a tool_output containing timeout_info with no partial results available", + target_fixture="tool_output", +) +def timeout_tool_output_with_no_partial_results(): + return {"timeout_info": {"partial_results": []}} + + +@when( + "I call compile_post_execution_signatures with the pre_signatures dict and tool_output" +) +def compile_post_execution_signatures(context, pre_signatures, tool_output): + context["result"] = context["signature_manager"].compile_post_execution_signatures( + pre_signatures, tool_output + ) + + +@then("the returned dict contains every key-value pair from pre_signatures unchanged") +@then( + "all pre-execution fields from pre_signatures are present and unchanged in the returned dict" +) +def pre_signatures_unchanged(context, pre_signatures): + result = context["result"] + for key, value in pre_signatures.items(): + assert key in result + assert result[key] == value + + +@then("the returned dict contains end_time as a UTC ISO 8601 string") +def result_contains_iso8601_end_time(context): + end_time = context["result"]["end_time"] + assert isinstance(end_time, str) + _parse_iso8601_utc(end_time) + + +@then( + parsers.parse( + 'end_time is chronologically greater than or equal to start_time "{start_time}"' + ) +) +def end_time_at_or_after_start_time(context, start_time): + end_time_dt = _parse_iso8601_utc(context["result"]["end_time"]) + start_time_dt = _parse_iso8601_utc(start_time) + assert end_time_dt >= start_time_dt + + +@then(parsers.parse('the returned dict contains execution_status equal to "{status}"')) +@then(parsers.parse('execution_status equals "{status}"')) +def execution_status_equals(context, status): + assert context["result"]["execution_status"] == status + + +@then(parsers.parse('end_time equals "{expected_end_time}"')) +def end_time_equals(context, expected_end_time): + assert context["result"]["end_time"] == expected_end_time + + +@then( + 'the returned dict contains the partial results ["result-A", "result-B"] from timeout_info' +) +def contains_timeout_partial_results(context): + assert context["result"]["partial_results"] == ["result-A", "result-B"] + + +# -------------------------------------------------- +# Multi-target post-execution scenario +# -------------------------------------------------- + + +@given( + "the pre_signatures is replaced by a list of 3 dicts each with a distinct target_ipv4", + target_fixture="pre_signatures", +) +def pre_signatures_multi_target_list(): + return [ + { + "source_ipv4": "172.17.0.2", + "target_ipv4": "10.0.0.1", + "start_time": "2024-06-26T06:00:00Z", + }, + { + "source_ipv4": "172.17.0.2", + "target_ipv4": "10.0.0.2", + "start_time": "2024-06-26T06:00:00Z", + }, + { + "source_ipv4": "172.17.0.2", + "target_ipv4": "10.0.0.3", + "start_time": "2024-06-26T06:00:00Z", + }, + ] + + +@then("the returned value is a list of exactly 3 dicts") +def result_is_list_of_three_dicts(context): + result = context["result"] + assert isinstance(result, list) + assert len(result) == 3 + assert all(isinstance(item, dict) for item in result) + + +@then( + parsers.parse( + 'every dict in the returned list contains execution_status equal to "{status}"' + ) +) +def every_dict_has_execution_status(context, status): + for item in context["result"]: + assert item["execution_status"] == status + + +@then("every dict in the returned list contains end_time as a UTC ISO 8601 string") +def every_dict_has_iso8601_end_time(context): + for item in context["result"]: + assert isinstance(item["end_time"], str) + _parse_iso8601_utc(item["end_time"]) + + +@then( + "every dict in the returned list preserves its original target_ipv4 and source_ipv4 fields" +) +def every_dict_preserves_pre_execution_fields(context, pre_signatures): + result = context["result"] + assert len(result) == len(pre_signatures) + for original, merged in zip(pre_signatures, result): + assert merged["target_ipv4"] == original["target_ipv4"] + assert merged["source_ipv4"] == original["source_ipv4"] diff --git a/test/signatures/test_signature_manager_pre_execution.py b/test/signatures/test_signature_manager_pre_execution.py new file mode 100644 index 0000000..f7581dc --- /dev/null +++ b/test/signatures/test_signature_manager_pre_execution.py @@ -0,0 +1,485 @@ +import ipaddress +from datetime import datetime, timedelta, timezone +from unittest.mock import MagicMock, patch + +import pytest +from pytest_bdd import given, parsers, scenario, then, when + +from pyoaev.signatures.models import ( + CloudInjectorConfig, + ExternalInjectorConfig, + NetworkInjectorConfig, + build_network_configs, +) +from pyoaev.signatures.signature_manager import SignatureManager + +# -------------------------------------------------- +# SCENARIOS +# -------------------------------------------------- + + +@scenario( + "features/signature_manager_pre_execution.feature", + "Network category returns required IP and timing fields and no cloud or query fields", +) +def test_network_category_required_fields(): + pass + + +@scenario( + "features/signature_manager_pre_execution.feature", + "Network hostname target returns hostname and no target_ipv4", +) +def test_network_hostname_target(): + pass + + +@scenario( + "features/signature_manager_pre_execution.feature", + "Cloud category returns required cloud identity fields and no IP fields", +) +def test_cloud_category_required_fields(): + pass + + +@scenario( + "features/signature_manager_pre_execution.feature", + "External category returns scan target fields and no source IP", +) +def test_external_category_fields(): + pass + + +@scenario( + "features/signature_manager_pre_execution.feature", + "Network multi-target returns one dict per target with a shared source IP", +) +def test_network_multi_target(): + pass + + +@scenario( + "features/signature_manager_pre_execution.feature", + "All network multi-target dicts share the same source_ipv4", +) +def test_network_multi_target_shared_source(): + pass + + +@scenario( + "features/signature_manager_pre_execution.feature", + "Cloud multi-region returns one dict per region with a shared account ID", +) +def test_cloud_multi_region(): + pass + + +@scenario( + "features/signature_manager_pre_execution.feature", + "Builder classifies a mixed list of targets into typed configs", +) +def test_builder_classifies_mixed_targets(): + pass + + +@scenario( + "constraints/signature_manager_pre_execution_constraints.feature", + "start_time is generated at method-call time not at class instantiation time", +) +def test_start_time_generated_at_call_time(): + pass + + +# -------------------------------------------------- +# FIXTURE CONTEXT +# -------------------------------------------------- + + +@pytest.fixture +def context(): + return {} + + +# -------------------------------------------------- +# HELPERS +# -------------------------------------------------- + + +def parse_utc_iso8601(value): + parsed = datetime.fromisoformat(value.replace("Z", "+00:00")) + return parsed.astimezone(timezone.utc) + + +# -------------------------------------------------- +# GIVEN +# -------------------------------------------------- + + +@given( + "a SignatureManager initialised with constructor SignatureManager(client, logger)", + target_fixture="signature_manager", +) +def signature_manager(): + return SignatureManager(client=MagicMock(), logger=None) + + +@given( + parsers.parse('a NetworkInjectorConfig with target_ipv4="{target_ipv4}"'), + target_fixture="config", +) +def network_config_ipv4(target_ipv4): + return NetworkInjectorConfig(target_ipv4=target_ipv4) + + +@given( + parsers.parse('a NetworkInjectorConfig with target_hostname="{target_hostname}"'), + target_fixture="config", +) +def network_config_hostname(target_hostname): + return NetworkInjectorConfig(target_hostname=target_hostname) + + +@given( + parsers.parse( + 'a CloudInjectorConfig with cloud_provider="{cloud_provider}", ' + 'cloud_account_id="{cloud_account_id}", cloud_region="{cloud_region}", ' + 'and target_service="{target_service}"' + ), + target_fixture="config", +) +def cloud_config_single( + cloud_provider, + cloud_account_id, + cloud_region, + target_service, +): + return CloudInjectorConfig( + cloud_provider=cloud_provider, + cloud_account_id=cloud_account_id, + cloud_region=cloud_region, + target_service=target_service, + ) + + +@given( + parsers.parse( + 'an ExternalInjectorConfig with target_ipv4="{target_ipv4}" and query="{query}"' + ), + target_fixture="config", +) +def external_config_single(target_ipv4, query): + return ExternalInjectorConfig(target_ipv4=target_ipv4, query=query) + + +@given( + parsers.parse( + "a list of 3 NetworkInjectorConfig with target_ipv4 " + '"{ip_1}", "{ip_2}", "{ip_3}"' + ), + target_fixture="config", +) +def network_config_list(ip_1, ip_2, ip_3): + return [ + NetworkInjectorConfig(target_ipv4=ip_1), + NetworkInjectorConfig(target_ipv4=ip_2), + NetworkInjectorConfig(target_ipv4=ip_3), + ] + + +@given( + "a list of 3 NetworkInjectorConfig built from default IPv4 targets", + target_fixture="config", +) +def network_config_list_default(): + return [ + NetworkInjectorConfig(target_ipv4="10.0.0.1"), + NetworkInjectorConfig(target_ipv4="10.0.0.2"), + NetworkInjectorConfig(target_ipv4="10.0.0.3"), + ] + + +@given( + parsers.parse( + 'a list of 3 CloudInjectorConfig with cloud_account_id="{cloud_account_id}" ' + 'and regions "{region_1}", "{region_2}", "{region_3}"' + ), + target_fixture="config", +) +def cloud_config_list(cloud_account_id, region_1, region_2, region_3): + return [ + CloudInjectorConfig( + cloud_provider="aws", + cloud_account_id=cloud_account_id, + cloud_region=region, + ) + for region in (region_1, region_2, region_3) + ] + + +@given( + parsers.parse('a raw mixed target list "{value_1}", "{value_2}", "{value_3}"'), + target_fixture="raw_targets", +) +def raw_mixed_target_list(value_1, value_2, value_3): + return [value_1, value_2, value_3] + + +@given( + "the running container has a resolvable IPv4 address", + target_fixture="source_ipv4", +) +def resolvable_container_ipv4(request): + patcher = patch( + "pyoaev.signatures.signature_manager.SignatureManager.resolve_container_ip", + return_value="172.17.0.2", + ) + patcher.start() + request.addfinalizer(patcher.stop) + return "172.17.0.2" + + +@given( + parsers.parse( + 'the running container has a resolvable IPv4 address "{source_ipv4}"' + ), + target_fixture="source_ipv4", +) +def resolvable_container_ipv4_explicit(request, source_ipv4): + patcher = patch( + "pyoaev.signatures.signature_manager.SignatureManager.resolve_container_ip", + return_value=source_ipv4, + ) + patcher.start() + request.addfinalizer(patcher.stop) + return source_ipv4 + + +@given( + "a SignatureManager that was instantiated at timestamp T0", + target_fixture="signature_manager", +) +def signature_manager_at_t0(context): + t0 = datetime(2024, 6, 26, 6, 6, 1, tzinfo=timezone.utc) + context["t0"] = t0 + manager = SignatureManager(client=MagicMock(), logger=None) + manager._test_t0 = t0 + return manager + + +@given( + "5 seconds elapse after instantiation", + target_fixture="t1", +) +def elapsed_5_seconds(context): + t1 = context["t0"] + timedelta(seconds=5) + context["t1"] = t1 + return t1 + + +# -------------------------------------------------- +# WHEN +# -------------------------------------------------- + + +@when( + "I call compile_pre_execution_signatures with the config", + target_fixture="result", +) +def call_compile_with_config(signature_manager, config): + return signature_manager.compile_pre_execution_signatures(config=config) + + +@when( + "I call compile_pre_execution_signatures with the config list", + target_fixture="result", +) +def call_compile_with_config_list(signature_manager, config): + return signature_manager.compile_pre_execution_signatures(config=config) + + +@when( + "I call compile_pre_execution_signatures with the config at timestamp T1", + target_fixture="result", +) +def call_compile_at_t1(signature_manager, config, t1): + with patch.object(signature_manager, "_utcnow", return_value=t1): + return signature_manager.compile_pre_execution_signatures(config=config) + + +@when( + "I build network configs from the raw list", + target_fixture="built_configs", +) +def build_configs_from_raw(raw_targets): + return build_network_configs(raw_targets) + + +# -------------------------------------------------- +# THEN +# -------------------------------------------------- + + +@then("the returned dict contains source_ipv4 as a non-empty valid IPv4 address string") +def source_ipv4_is_valid(result): + source_ipv4 = result["source_ipv4"] + assert source_ipv4 + ipaddress.IPv4Address(source_ipv4) + + +@then("the returned dict contains start_time as a UTC ISO 8601 string") +def start_time_is_utc_iso8601(result): + start_time = result["start_time"] + parsed = parse_utc_iso8601(start_time) + assert parsed.tzinfo is not None + + +@then(parsers.parse('the returned dict contains target_ipv4 equal to "{value}"')) +def returned_dict_target_ipv4(result, value): + assert result["target_ipv4"] == value + + +@then(parsers.parse('the returned dict contains target_hostname equal to "{value}"')) +def returned_dict_target_hostname(result, value): + assert result["target_hostname"] == value + + +@then(parsers.parse('the returned dict contains cloud_provider equal to "{value}"')) +def returned_dict_cloud_provider(result, value): + assert result["cloud_provider"] == value + + +@then(parsers.parse('the returned dict contains cloud_account_id equal to "{value}"')) +def returned_dict_cloud_account_id(result, value): + assert result["cloud_account_id"] == value + + +@then(parsers.parse('the returned dict contains cloud_region equal to "{value}"')) +def returned_dict_cloud_region(result, value): + assert result["cloud_region"] == value + + +@then(parsers.parse('the returned dict contains target_service equal to "{value}"')) +def returned_dict_target_service(result, value): + assert result["target_service"] == value + + +@then(parsers.parse('the returned dict contains query equal to "{value}"')) +def returned_dict_query(result, value): + assert result["query"] == value + + +@then(parsers.parse("the returned dict does not contain {field}")) +def returned_dict_does_not_contain_field(result, field): + assert field not in result + + +@then("the return value is a list of exactly 3 dicts") +def return_value_is_list_of_three_dicts(result): + assert isinstance(result, list) + assert len(result) == 3 + assert all(isinstance(item, dict) for item in result) + + +@then(parsers.parse("the return value is a list of {count:d} dicts")) +def return_value_is_list_of_n_dicts(result, count): + assert isinstance(result, list) + assert len(result) == count + assert all(isinstance(item, dict) for item in result) + + +@then( + parsers.parse( + 'the dict at position {index:d} contains target_ipv4 equal to "{target_ip}"' + ) +) +def list_dict_contains_target_ipv4_at_position(result, index, target_ip): + assert result[index]["target_ipv4"] == target_ip + + +@then( + parsers.parse( + 'the dict at position {index:d} contains source_ipv4 equal to "{source_ipv4}"' + ) +) +def list_dict_contains_source_ipv4_at_position( + result, + index, + source_ipv4, +): + assert result[index]["source_ipv4"] == source_ipv4 + + +@then( + parsers.parse( + 'the dict at position {index:d} contains cloud_region equal to "{region}"' + ) +) +def list_dict_contains_cloud_region_at_position(result, index, region): + assert result[index]["cloud_region"] == region + + +@then( + parsers.parse( + 'the dict at position {index:d} contains cloud_account_id equal to "{account_id}"' + ) +) +def list_dict_contains_cloud_account_id_at_position(result, index, account_id): + assert result[index]["cloud_account_id"] == account_id + + +@then("all 3 dicts contain the same source_ipv4 value") +def all_dicts_share_same_source_ipv4(result): + assert isinstance(result, list) + assert len(result) == 3 + source_values = {item["source_ipv4"] for item in result} + assert len(source_values) == 1 + ipaddress.IPv4Address(next(iter(source_values))) + + +@then("the start_time in the returned dict equals T1 within 1 second tolerance") +def start_time_equals_t1_with_tolerance(result, t1): + start_time = parse_utc_iso8601(result["start_time"]) + delta_seconds = abs((start_time - t1).total_seconds()) + assert delta_seconds <= 1 + + +@then("start_time does not equal T0") +def start_time_not_equal_t0(result, signature_manager): + start_time = parse_utc_iso8601(result["start_time"]) + assert start_time != signature_manager._test_t0 + + +@then(parsers.parse("the builder returns {count:d} NetworkInjectorConfig")) +def builder_returns_n_configs(built_configs, count): + assert isinstance(built_configs, list) + assert len(built_configs) == count + assert all(isinstance(c, NetworkInjectorConfig) for c in built_configs) + + +@then( + parsers.parse('the config at position {index:d} has target_ipv4 equal to "{value}"') +) +def config_has_target_ipv4(built_configs, index, value): + assert built_configs[index].target_ipv4 == value + assert built_configs[index].target_ipv6 is None + assert built_configs[index].target_hostname is None + + +@then( + parsers.parse('the config at position {index:d} has target_ipv6 equal to "{value}"') +) +def config_has_target_ipv6(built_configs, index, value): + assert built_configs[index].target_ipv6 == value + assert built_configs[index].target_ipv4 is None + assert built_configs[index].target_hostname is None + + +@then( + parsers.parse( + 'the config at position {index:d} has target_hostname equal to "{value}"' + ) +) +def config_has_target_hostname(built_configs, index, value): + assert built_configs[index].target_hostname == value + assert built_configs[index].target_ipv4 is None + assert built_configs[index].target_ipv6 is None diff --git a/test/signatures/test_signature_manager_transmission.py b/test/signatures/test_signature_manager_transmission.py new file mode 100644 index 0000000..6519bd6 --- /dev/null +++ b/test/signatures/test_signature_manager_transmission.py @@ -0,0 +1,652 @@ +import ipaddress +import json +from types import SimpleNamespace +from unittest.mock import MagicMock, call + +import pytest +from pytest_bdd import given, parsers, scenario, then, when + +from pyoaev.apis.signature import SignatureApiManager +from pyoaev.exceptions import OpenAEVUpdateError, SignatureTransmissionError +from pyoaev.signatures.signature_manager import SignatureManager + + +@scenario( + "features/signature_manager_transmission.feature", + "HTTP 2xx response is treated as successful transmission", +) +def test_http_2xx_response_is_treated_as_successful_transmission(): + pass + + +@scenario( + "features/signature_manager_transmission.feature", + "send_signatures posts to the inject callback with the agreed nested schema", +) +def test_send_signatures_posts_with_agreed_nested_schema(): + pass + + +@scenario( + "constraints/signature_manager_transmission_constraints.feature", + "Payload exceeding MAX_PAYLOAD_SIZE is auto-chunked with chunk metadata", +) +def test_payload_exceeding_max_payload_size_is_split_into_sequential_chunks(): + pass + + +@scenario( + "constraints/signature_manager_transmission_constraints.feature", + "HTTP 5xx response triggers exponential backoff retry for up to 3 additional attempts", +) +def test_http_5xx_response_triggers_exponential_backoff_retry(): + pass + + +@scenario( + "constraints/signature_manager_transmission_constraints.feature", + "HTTP 4xx response raises an exception immediately with no retries and no sleep", +) +def test_http_4xx_response_raises_exception_immediately(): + pass + + +@scenario( + "features/signature_manager_transmission.feature", + "resolve_container_ip returns a valid IPv4 in each supported execution environment", +) +def test_resolve_container_ip_returns_valid_ipv4(): + pass + + +@scenario( + "constraints/signature_manager_transmission_constraints.feature", + "resolve_container_ip returns unknown and emits exactly one warning when all strategies fail", +) +def test_resolve_container_ip_returns_unknown_when_all_strategies_fail(): + pass + + +@scenario( + "features/signature_manager_transmission.feature", + "Payload schema groups signature values by expectation_type within each target", +) +def test_payload_schema_groups_signature_values_by_expectation_type(): + pass + + +@pytest.fixture +def context(): + return {} + + +_CANONICAL_SIGNATURE_TARGET = { + "agent": "b044fbc7-f277-4c8c-aeae-5c5497598c51", + "asset": "asset-host-a", + "asset_group": "asset-group-internal", +} + + +def _build_signature_payload( + signature_value="203.0.113.5", + expectation_types=None, +): + if expectation_types is None: + expectation_types = ["DETECTION"] + return { + "targets": [ + { + "signature_target": dict(_CANONICAL_SIGNATURE_TARGET), + "signature_values": [ + { + "expectation_type": expectation_type, + "values": [ + { + "signature_type": "public_ip", + "signature_value": ( + signature_value + if expectation_type == "DETECTION" + else "198.51.100.10" + ), + } + ], + } + for expectation_type in expectation_types + ], + } + ] + } + + +@given( + "a SignatureManager initialised with constructor SignatureManager(client, logger)" +) +def signature_manager(context, monkeypatch): + logger = MagicMock() + mock_client = MagicMock() + sleep_mock = MagicMock() + captured_calls = [] + + def _http_post(*args, **kwargs): + path = kwargs.get("path", args[0] if args else None) + post_data = kwargs.get("post_data", args[1] if len(args) > 1 else None) + captured_calls.append( + { + "path": path, + "post_data": post_data, + } + ) + status_plan = context.get("status_plan", [200]) + status_code = status_plan[min(len(captured_calls) - 1, len(status_plan) - 1)] + if status_code >= 400: + raise OpenAEVUpdateError( + f"HTTP {status_code}", + response_code=status_code, + response_body=context.get("error_body", "").encode(), + ) + return SimpleNamespace(status_code=status_code) + + mock_client.http_post.side_effect = _http_post + + # Wire up the real SignatureApi so delegation works + sig_api = SignatureApiManager(mock_client) + mock_client.signature = sig_api + + monkeypatch.setattr( + "pyoaev.apis.signature.time.sleep", + sleep_mock, + ) + + context["logger"] = logger + context["mock_client"] = mock_client + context["sleep_mock"] = sleep_mock + context["captured_calls"] = captured_calls + context["status_plan"] = [200] + context["error_body"] = "" + context["inject_id"] = "inject-abc-001" + context["phase"] = "execution_complete" + context["signatures"] = _build_signature_payload() + context["signature_manager"] = SignatureManager(mock_client, logger=logger) + + +@given(parsers.parse('a compiled post-execution payload for inject_id "{inject_id}"')) +def compiled_post_execution_payload(context, inject_id): + context["inject_id"] = inject_id + context["signatures"] = _build_signature_payload() + + +@given( + parsers.parse( + 'a compiled payload with 1 target, expectation_type "{expectation_type}", signature_type "{signature_type}", signature_value "{signature_value}"' + ) +) +def compiled_payload_single_target( + context, + expectation_type, + signature_type, + signature_value, +): + context["signatures"] = { + "targets": [ + { + "signature_target": dict(_CANONICAL_SIGNATURE_TARGET), + "signature_values": [ + { + "expectation_type": expectation_type, + "values": [ + { + "signature_type": signature_type, + "signature_value": signature_value, + } + ], + } + ], + } + ] + } + + +@given( + "a compiled payload whose serialised size exceeds MAX_PAYLOAD_SIZE by at least a factor of 2" +) +def compiled_large_payload(context): + context["signature_manager"] = SignatureManager( + context["mock_client"], + logger=context["logger"], + max_payload_size=700, + ) + context["signatures"] = { + "targets": [ + { + "signature_target": { + "agent": f"agent-{index:08d}-0000-0000-0000-000000000000", + "asset": f"asset-{index}", + "asset_group": "asset-group-bulk", + }, + "signature_values": [ + { + "expectation_type": "DETECTION", + "values": [ + { + "signature_type": "public_ip", + "signature_value": "203.0.113.123", + }, + { + "signature_type": "hostname", + "signature_value": f"host-{index}." + ("a" * 140), + }, + ], + } + ], + } + for index in range(6) + ] + } + + +@given(parsers.parse("the backend responds with HTTP {status_code:d}")) +def backend_responds_with_http_status(context, status_code): + context["status_plan"] = [status_code] + context["error_body"] = "" + + +@given("the backend responds with HTTP 503 on every attempt") +def backend_responds_with_http_503_every_time(context): + context["status_plan"] = [503, 503, 503, 503] + + +@given( + parsers.parse("the backend responds with HTTP {status_code:d} and body '{body}'") +) +def backend_responds_with_http_status_and_body(context, status_code, body): + context["status_plan"] = [status_code] + context["error_body"] = body + + +@given(parsers.parse('a SignatureManager running in a "{environment}" environment')) +def signature_manager_environment(context, monkeypatch, environment): + if environment == "Docker": + monkeypatch.setattr( + "pyoaev.signatures.signature_manager.socket.gethostbyname", + lambda _: "172.17.0.2", + ) + monkeypatch.setattr( + "pyoaev.signatures.signature_manager.subprocess.run", + lambda *args, **kwargs: SimpleNamespace( + returncode=0, + stdout="172.17.0.2\n", + ), + ) + elif environment == "Kubernetes": + monkeypatch.setattr( + "pyoaev.signatures.signature_manager.socket.gethostbyname", + lambda _: (_ for _ in ()).throw(OSError("socket fail")), + ) + monkeypatch.setattr( + "pyoaev.signatures.signature_manager.subprocess.run", + lambda *args, **kwargs: SimpleNamespace( + returncode=0, + stdout="10.244.0.8\n", + ), + ) + else: + monkeypatch.setattr( + "pyoaev.signatures.signature_manager.socket.gethostbyname", + lambda _: "192.0.2.20", + ) + + +@given("all IP resolution strategies are mocked to fail") +def all_ip_resolution_strategies_fail(context, monkeypatch): + monkeypatch.setattr( + "pyoaev.signatures.signature_manager.socket.gethostbyname", + lambda _: (_ for _ in ()).throw(OSError("socket fail")), + ) + monkeypatch.setattr( + "pyoaev.signatures.signature_manager.subprocess.run", + lambda *args, **kwargs: (_ for _ in ()).throw(RuntimeError("command fail")), + ) + + +@given( + parsers.parse( + 'a compiled payload for 1 target with signatures of expectation_type "{expectation_a}" and expectation_type "{expectation_b}"' + ) +) +def compiled_payload_grouped_by_expectation( + context, + expectation_a, + expectation_b, +): + context["signatures"] = { + "targets": [ + { + "signature_target": dict(_CANONICAL_SIGNATURE_TARGET), + "signature_values": [ + { + "expectation_type": expectation_a, + "signature_type": "public_ip", + "signature_value": "203.0.113.5", + }, + { + "expectation_type": expectation_b, + "signature_type": "public_ip", + "signature_value": "198.51.100.10", + }, + { + "expectation_type": expectation_a, + "signature_type": "hostname", + "signature_value": "host-a.internal", + }, + ], + } + ] + } + + +@when( + parsers.parse( + 'I call send_signatures for inject_id "{inject_id}" with phase "{phase}"' + ) +) +def call_send_signatures(context, inject_id, phase): + context["inject_id"] = inject_id + context["phase"] = phase + context["send_exception"] = None + try: + context["signature_manager"].send_signatures( + inject_id, + phase, + context["signatures"], + ) + except Exception as exc: + context["send_exception"] = exc + + +@when("I call resolve_container_ip") +def call_resolve_container_ip(context): + context["resolve_exception"] = None + try: + context["resolved_ip"] = context["signature_manager"].resolve_container_ip() + except Exception as exc: + context["resolve_exception"] = exc + + +@then("send_signatures completes without raising an exception") +def send_signatures_completes_without_exception(context): + assert context["send_exception"] is None + + +@then( + parsers.parse( + "a POST request is sent to /injects/{inject_id}/callback", + ) +) +def assert_post_request_sent_to_callback(context, inject_id): + assert context["captured_calls"] + assert context["captured_calls"][-1]["path"] == f"/injects/{inject_id}/callback" + + +@then("the POST request body contains signatures.targets as a list") +def assert_targets_is_list(context): + body = context["captured_calls"][-1]["post_data"] + assert isinstance(body["expectation_signature"]["targets"], list) + + +@then( + parsers.parse( + 'signatures.targets[0].signature_values[0].expectation_type equals "{expected_value}"' + ) +) +def assert_expectation_type(context, expected_value): + body = context["captured_calls"][-1]["post_data"] + assert body["expectation_signature"]["targets"][0]["signature_values"][0][ + "expectation_type" + ] == (expected_value) + + +@then( + parsers.parse( + 'signatures.targets[0].signature_values[0].values[0].signature_type equals "{expected_value}"' + ) +) +def assert_signature_type(context, expected_value): + body = context["captured_calls"][-1]["post_data"] + assert ( + body["expectation_signature"]["targets"][0]["signature_values"][0]["values"][0][ + "signature_type" + ] + == expected_value + ) + + +@then( + parsers.parse( + 'signatures.targets[0].signature_values[0].values[0].signature_value equals "{expected_value}"' + ) +) +def assert_signature_value(context, expected_value): + body = context["captured_calls"][-1]["post_data"] + assert ( + body["expectation_signature"]["targets"][0]["signature_values"][0]["values"][0][ + "signature_value" + ] + == expected_value + ) + + +@then("signatures.targets[0] contains a signature_target key") +def assert_signature_target_key(context): + body = context["captured_calls"][-1]["post_data"] + assert "signature_target" in body["expectation_signature"]["targets"][0] + + +@then( + parsers.parse( + "the payload is sent as multiple sequential POST requests to /injects/{inject_id}/callback", + ) +) +def assert_payload_sent_as_multiple_chunks(context, inject_id): + assert context["send_exception"] is None + assert len(context["captured_calls"]) > 1 + assert all( + call_item["path"] == f"/injects/{inject_id}/callback" + for call_item in context["captured_calls"] + ) + + +@then("each POST request body contains chunk_index as a 0-based integer") +def assert_chunk_index_present(context): + for index, call_item in enumerate(context["captured_calls"]): + post_data = call_item["post_data"] + assert isinstance(post_data["chunk_index"], int) + assert post_data["chunk_index"] == index + + +@then( + "each POST request body contains total_chunks as a positive integer matching the total number of chunks sent" +) +def assert_total_chunks_present(context): + total_chunks = len(context["captured_calls"]) + for call_item in context["captured_calls"]: + post_data = call_item["post_data"] + assert isinstance(post_data["total_chunks"], int) + assert post_data["total_chunks"] > 0 + assert post_data["total_chunks"] == total_chunks + + +@then( + 'each POST request body contains only "signatures", "chunk_index" and "total_chunks" at the top level' +) +def assert_chunked_envelope_is_strict(context): + expected_keys = {"expectation_signature", "chunk_index", "total_chunks", "phase"} + for call_item in context["captured_calls"]: + post_data = call_item["post_data"] + assert set(post_data.keys()) == expected_keys, ( + f"Chunked envelope must contain exactly {expected_keys}, " + f"got {set(post_data.keys())}" + ) + + +@then("the union of targets across all POST requests equals the original target set") +def assert_targets_union_matches_original(context): + original_targets = context["signatures"]["targets"] + sent_targets = [ + target + for call_item in context["captured_calls"] + for target in call_item["post_data"]["expectation_signature"]["targets"] + ] + assert len(sent_targets) == len(original_targets), ( + f"Expected {len(original_targets)} targets across all chunks, " + f"got {len(sent_targets)}" + ) + for original, sent in zip(original_targets, sent_targets): + assert sent["signature_target"] == original["signature_target"] + + +@then("no individual POST request body exceeds MAX_PAYLOAD_SIZE bytes") +def assert_payload_size_per_chunk(context): + max_payload_size = context["signature_manager"].max_payload_size + for call_item in context["captured_calls"]: + post_data = call_item["post_data"] + payload_size = len(json.dumps(post_data).encode()) + assert payload_size <= max_payload_size + + +@then( + parsers.parse( + "send_signatures sends a total of {total_requests:d} POST requests to /injects/{inject_id}/callback" + ) +) +def assert_total_post_requests(context, total_requests, inject_id): + assert len(context["captured_calls"]) == total_requests + assert all( + call_item["path"] == f"/injects/{inject_id}/callback" + for call_item in context["captured_calls"] + ) + + +@then( + "a WARNING log message containing the retry attempt number is emitted before each of the 3 retry attempts" +) +def assert_warning_logs_for_retries(context): + assert context["logger"].warning.call_count == 3 + warning_messages = [ + " ".join(str(arg) for arg in warning_call.args) + for warning_call in context["logger"].warning.call_args_list + ] + assert any("1" in message for message in warning_messages) + assert any("2" in message for message in warning_messages) + assert any("3" in message for message in warning_messages) + + +@then(parsers.parse("the wait before attempt {attempt:d} is {seconds:d} second")) +@then(parsers.parse("the wait before attempt {attempt:d} is {seconds:d} seconds")) +def assert_wait_before_attempt(context, attempt, seconds): + assert context["sleep_mock"].call_args_list[attempt - 2] == call(seconds) + + +@then("a SignatureTransmissionError is raised after all retries are exhausted") +def assert_signature_transmission_error_after_retries(context): + assert isinstance(context["send_exception"], SignatureTransmissionError) + + +@then( + parsers.parse( + "only {request_count:d} POST request is sent to /injects/{inject_id}/callback" + ) +) +def assert_single_post_request(context, request_count, inject_id): + assert len(context["captured_calls"]) == request_count + assert context["captured_calls"][0]["path"] == f"/injects/{inject_id}/callback" + + +@then( + parsers.parse( + "an ERROR log message containing status code {status_code:d} and the response body is emitted" + ) +) +def assert_error_log_contains_status_and_body(context, status_code): + assert context["logger"].error.call_count >= 1 + message_text = " ".join( + str(arg) + for call_args in context["logger"].error.call_args_list + for arg in call_args.args + ) + assert str(status_code) in message_text + assert context["error_body"] in message_text + + +@then("an exception is raised immediately") +def assert_exception_raised_immediately(context): + assert context["send_exception"] is not None + + +@then("no sleep or wait occurs before the exception is raised") +def assert_no_sleep_occurs(context): + assert context["sleep_mock"].call_count == 0 + + +@then("the returned value is a non-empty valid IPv4 address string") +def assert_returned_value_valid_ipv4(context): + assert context["resolve_exception"] is None + resolved_ip = context["resolved_ip"] + assert isinstance(resolved_ip, str) + assert resolved_ip.strip() != "" + assert ipaddress.ip_address(resolved_ip).version == 4 + + +@then(parsers.parse('the returned value is the string "{expected_value}"')) +def assert_returned_value_matches(context, expected_value): + assert context["resolved_ip"] == expected_value + + +@then(parsers.parse("exactly {count:d} WARNING log message is emitted")) +def assert_warning_count(context, count): + assert context["logger"].warning.call_count == count + + +@then("no exception propagates from resolve_container_ip") +def assert_no_exception_from_resolve_container_ip(context): + assert context["resolve_exception"] is None + + +@then( + "the POST request body nests signature values under separate expectation_type entries within signatures.targets[0].signature_values" +) +def assert_signature_values_nested_by_expectation_type(context): + body = context["captured_calls"][-1]["post_data"] + entries = body["expectation_signature"]["targets"][0]["signature_values"] + expectation_types = {entry["expectation_type"] for entry in entries} + assert expectation_types == {"DETECTION", "PREVENTION"} + + +@then( + 'the entry with expectation_type "DETECTION" contains only DETECTION signature values' +) +def assert_detection_values_grouped_correctly(context): + body = context["captured_calls"][-1]["post_data"] + entries = body["expectation_signature"]["targets"][0]["signature_values"] + detection_entry = next( + entry for entry in entries if entry["expectation_type"] == "DETECTION" + ) + detection_values = {value["signature_value"] for value in detection_entry["values"]} + assert detection_values == {"203.0.113.5", "host-a.internal"} + assert "198.51.100.10" not in detection_values + + +@then( + 'the entry with expectation_type "PREVENTION" contains only PREVENTION signature values' +) +def assert_prevention_values_grouped_correctly(context): + body = context["captured_calls"][-1]["post_data"] + entries = body["expectation_signature"]["targets"][0]["signature_values"] + prevention_entry = next( + entry for entry in entries if entry["expectation_type"] == "PREVENTION" + ) + prevention_values = { + value["signature_value"] for value in prevention_entry["values"] + } + assert prevention_values == {"198.51.100.10"} + assert "203.0.113.5" not in prevention_values + assert "host-a.internal" not in prevention_values