From 018afad5eacdb9486a8e98b7b33c331c2b5beed1 Mon Sep 17 00:00:00 2001 From: Wolfgang Hotwagner Date: Mon, 27 Nov 2023 20:42:36 +0100 Subject: [PATCH] updated to pydantic2 --- .pre-commit-config.yaml | 2 +- log.txt | 0 pyproject.toml | 2 +- src/attackmate/schemas.py | 116 +++++++++++++++++++++----------------- 4 files changed, 66 insertions(+), 54 deletions(-) create mode 100644 log.txt diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index fec4f00..55a5e59 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -15,7 +15,7 @@ repos: - id: check-ast - repo: https://github.com/pre-commit/mirrors-mypy - rev: v1.5.0 + rev: v1.7.1 hooks: - id: mypy additional_dependencies: [pydantic, types-PyYAML, types-requests, types-paramiko, types-tabulate] diff --git a/log.txt b/log.txt new file mode 100644 index 0000000..e69de29 diff --git a/pyproject.toml b/pyproject.toml index 17d66d2..61d110e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,7 +9,7 @@ requires-python = ">=3.7" keywords = ["Pentest", "Attack", "Orchestration", "Hacking", "Simulating", "Attackchain"] license = {text = "GPL-3.0"} dependencies = [ - "pydantic ~= 1.10", + "pydantic ~= 2.5", "colorlog", "pymetasploit3", "pyaml", diff --git a/src/attackmate/schemas.py b/src/attackmate/schemas.py index b15ef41..592e14b 100644 --- a/src/attackmate/schemas.py +++ b/src/attackmate/schemas.py @@ -1,9 +1,25 @@ -from typing import List, Literal, Union, Optional, Dict -from pydantic import BaseModel, Field, validator +from typing import Annotated, List, Literal, Union, Optional, Dict +from pydantic import AfterValidator, BeforeValidator, field_validator, BaseModel, ValidationInfo +import re # https://stackoverflow.com/questions/71539448/using-different-pydantic-models-depending-on-the-value-of-fields - VAR_PATTERN = r'^\$[$a-zA-Z0-9_]+$|^[0-9]+$' +pattern = re.compile(VAR_PATTERN) + + +def transform_int_to_str(value) -> str: + return str(value) + + +def check_var_pattern(value: str, info: ValidationInfo) -> str: + global pattern + assert pattern.match(value), f'{info.field_name} must be a variable, integer or numeric string' + return value + + +StringNumber = Annotated[Optional[str | int], + BeforeValidator(transform_int_to_str), + AfterValidator(check_var_pattern)] class BaseCommand(BaseModel): @@ -27,7 +43,8 @@ def list_template_vars(self) -> List[str]: template_vars.append(k) return template_vars - @validator('background') + @field_validator('background') + @classmethod def bg_not_implemented_yet(cls, v): if cls in (MsfSessionCommand, IncludeCommand): raise ValueError('background mode is unsupported for this command') @@ -38,7 +55,7 @@ def bg_not_implemented_yet(cls, v): error_if_not: Optional[str] = None loop_if: Optional[str] = None loop_if_not: Optional[str] = None - loop_count: str = Field(pattern=VAR_PATTERN, default='3') + loop_count: StringNumber = '3' exit_on_error: bool = True save: Optional[str] = None cmd: str @@ -48,8 +65,8 @@ def bg_not_implemented_yet(cls, v): class SleepCommand(BaseCommand): type: Literal['sleep'] - min_sec: str = Field(pattern=VAR_PATTERN, default='0') - seconds: str = Field(pattern=VAR_PATTERN, default='1') + min_sec: StringNumber = '0' + seconds: StringNumber = '1' random: bool = False cmd: str = 'sleep' @@ -74,7 +91,7 @@ class WebServCommand(BaseCommand): type: Literal['webserv'] cmd: str = 'HTTP-GET' local_path: str - port: str = Field(pattern=VAR_PATTERN, default='8000') + port: StringNumber = '8000' address: str = '0.0.0.0' # nosec @@ -90,7 +107,7 @@ class FatherCommand(BaseCommand): hiddenport: str = 'D431' shell_pass: str = 'lobster' install_path: str = '/lib/selinux.so.3' - local_path: Optional[str] + local_path: Optional[str] = None arch: Literal['amd64'] = 'amd64' build_command: str = 'make' @@ -116,38 +133,33 @@ class RegExCommand(BaseCommand): class SSHBase(BaseCommand): - @validator('session') - def session_and_background_unsupported(cls, v, values, **kwargs): - if 'background' in values and values['background']: + @field_validator('session', 'creates_session') + @classmethod + def session_and_background_unsupported(cls, v, info: ValidationInfo) -> str: + if 'background' in info.data and info.data['background']: raise ValueError('background mode combined with session is unsupported for SSH') return v - @validator('creates_session') - def creates_session_and_background_unsupported(cls, v, values, **kwargs): - if 'background' in values and values['background']: - raise ValueError('background mode combined with session is unsupported for SSH') - return v - - hostname: Optional[str] - port: Optional[str] = Field(pattern=VAR_PATTERN, default=None) - username: Optional[str] - password: Optional[str] - passphrase: Optional[str] - key_filename: Optional[str] - creates_session: Optional[str] - session: Optional[str] + hostname: Optional[str] = None + port: StringNumber = None + username: Optional[str] = None + password: Optional[str] = None + passphrase: Optional[str] = None + key_filename: Optional[str] = None + creates_session: Optional[str] = None + session: Optional[str] = None clear_cache: bool = False timeout: float = 60 - jmp_hostname: Optional[str] - jmp_port: Optional[str] = Field(pattern=VAR_PATTERN, default=None) - jmp_username: Optional[str] + jmp_hostname: Optional[str] = None + jmp_port: StringNumber = None + jmp_username: Optional[str] = None class SSHCommand(SSHBase): type: Literal['ssh'] interactive: bool = False validate_prompt: bool = True - command_timeout: str = Field(pattern=VAR_PATTERN, default='15') + command_timeout: StringNumber = '15' prompts: List[str] = ['$ ', '# ', '> '] @@ -156,7 +168,7 @@ class SFTPCommand(SSHBase): cmd: Literal['get', 'put'] remote_path: str local_path: str - mode: Optional[str] + mode: Optional[str] = None class MsfSessionCommand(BaseCommand): @@ -166,7 +178,7 @@ class MsfSessionCommand(BaseCommand): write: bool = False read: bool = False session: str - end_str: Optional[str] + end_str: Optional[str] = None class MsfPayloadCommand(BaseCommand): @@ -179,19 +191,19 @@ class MsfPayloadCommand(BaseCommand): template: Optional[str] = None platform: Optional[str] = None keep_template_working: bool = False - nopsled_size: str = Field(pattern=VAR_PATTERN, default='0') - iter: str = Field(pattern=VAR_PATTERN, default='0') + nopsled_size: StringNumber = '0' + iter: StringNumber = '0' payload_options: Dict[str, str] = {} - local_path: Optional[str] + local_path: Optional[str] = None class MsfModuleCommand(BaseCommand): cmd: str type: Literal['msf-module'] - target: str = Field(pattern=VAR_PATTERN, default='0') - creates_session: Optional[str] - session: Optional[str] - payload: Optional[str] + target: StringNumber = '0' + creates_session: Optional[str] = None + session: Optional[str] = None + payload: Optional[str] = None options: Dict[str, str] = {} payload_options: Dict[str, str] = {} @@ -219,10 +231,10 @@ class HttpClientCommand(BaseCommand): cmd: Literal['GET', 'POST', 'PUT', 'DELETE', 'PATCH', 'HEAD', 'OPTIONS'] = 'GET' url: str output_headers: bool = False - headers: Optional[Dict[str, str]] - cookies: Optional[Dict[str, str]] - data: Optional[Dict[str, str]] - local_path: Optional[str] + headers: Optional[Dict[str, str]] = None + cookies: Optional[Dict[str, str]] = None + data: Optional[Dict[str, str]] = None + local_path: Optional[str] = None useragent: str = 'AttackMate' follow: bool = False verify: bool = False @@ -249,16 +261,16 @@ class SliverHttpsListenerCommand(BaseCommand): type: Literal['sliver'] cmd: Literal['start_https_listener'] host: str = '0.0.0.0' # nosec - port: str = Field(pattern=VAR_PATTERN, default='443') + port: StringNumber = '443' domain: str = '' website: str = '' acme: bool = False persistent: bool = False enforce_otp: bool = True randomize_jarm: bool = True - long_poll_timeout: str = Field(pattern=VAR_PATTERN, default='1') - long_poll_jitter: str = Field(pattern=VAR_PATTERN, default='2') - timeout: str = Field(pattern=VAR_PATTERN, default='60') + long_poll_timeout: StringNumber = '1' + long_poll_jitter: StringNumber = '3' + timeout: StringNumber = '60' class SliverGenerateCommand(BaseCommand): @@ -278,9 +290,9 @@ class SliverGenerateCommand(BaseCommand): 'SHARED_LIB', 'SHELLCODE'] = 'EXECUTABLE' name: str - filepath: Optional[str] + filepath: Optional[str] = None IsBeacon: bool = False - BeaconInterval: str = Field(pattern=VAR_PATTERN, default='120') + BeaconInterval: StringNumber = '120' RunAtLoad: bool = False Evasion: bool = False @@ -328,7 +340,7 @@ class SliverSessionNETSTATCommand(SliverSessionCommand): class SliverSessionEXECCommand(SliverSessionCommand): cmd: Literal['execute'] exe: str - args: Optional[List[str]] + args: Optional[List[str]] = None output: bool = True @@ -344,7 +356,7 @@ class SliverSessionLSCommand(SliverSessionCommand): class SliverSessionPROCDUMPCommand(SliverSessionCommand): cmd: Literal['process_dump'] local_path: str - pid: str = Field(pattern=VAR_PATTERN) + pid: StringNumber class SliverSessionRMCommand(SliverSessionCommand): @@ -356,7 +368,7 @@ class SliverSessionRMCommand(SliverSessionCommand): class SliverSessionTERMINATECommand(SliverSessionCommand): cmd: Literal['terminate'] - pid: str = Field(pattern=VAR_PATTERN) + pid: StringNumber force: bool = False