Skip to content

Commit

Permalink
updated to pydantic2
Browse files Browse the repository at this point in the history
  • Loading branch information
Wolfgang Hotwagner committed Nov 27, 2023
1 parent b326cbd commit 018afad
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 54 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ repos:
- id: check-ast

- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.5.0
rev: v1.7.1
hooks:
- id: mypy
additional_dependencies: [pydantic, types-PyYAML, types-requests, types-paramiko, types-tabulate]
Expand Down
Empty file added log.txt
Empty file.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ requires-python = ">=3.7"
keywords = ["Pentest", "Attack", "Orchestration", "Hacking", "Simulating", "Attackchain"]
license = {text = "GPL-3.0"}
dependencies = [
"pydantic ~= 1.10",
"pydantic ~= 2.5",
"colorlog",
"pymetasploit3",
"pyaml",
Expand Down
116 changes: 64 additions & 52 deletions src/attackmate/schemas.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,25 @@
from typing import List, Literal, Union, Optional, Dict
from pydantic import BaseModel, Field, validator
from typing import Annotated, List, Literal, Union, Optional, Dict
from pydantic import AfterValidator, BeforeValidator, field_validator, BaseModel, ValidationInfo
import re

# https://stackoverflow.com/questions/71539448/using-different-pydantic-models-depending-on-the-value-of-fields

VAR_PATTERN = r'^\$[$a-zA-Z0-9_]+$|^[0-9]+$'
pattern = re.compile(VAR_PATTERN)


def transform_int_to_str(value) -> str:
return str(value)


def check_var_pattern(value: str, info: ValidationInfo) -> str:
global pattern
assert pattern.match(value), f'{info.field_name} must be a variable, integer or numeric string'
return value


StringNumber = Annotated[Optional[str | int],
BeforeValidator(transform_int_to_str),
AfterValidator(check_var_pattern)]


class BaseCommand(BaseModel):
Expand All @@ -27,7 +43,8 @@ def list_template_vars(self) -> List[str]:
template_vars.append(k)
return template_vars

@validator('background')
@field_validator('background')
@classmethod
def bg_not_implemented_yet(cls, v):
if cls in (MsfSessionCommand, IncludeCommand):
raise ValueError('background mode is unsupported for this command')
Expand All @@ -38,7 +55,7 @@ def bg_not_implemented_yet(cls, v):
error_if_not: Optional[str] = None
loop_if: Optional[str] = None
loop_if_not: Optional[str] = None
loop_count: str = Field(pattern=VAR_PATTERN, default='3')
loop_count: StringNumber = '3'
exit_on_error: bool = True
save: Optional[str] = None
cmd: str
Expand All @@ -48,8 +65,8 @@ def bg_not_implemented_yet(cls, v):

class SleepCommand(BaseCommand):
type: Literal['sleep']
min_sec: str = Field(pattern=VAR_PATTERN, default='0')
seconds: str = Field(pattern=VAR_PATTERN, default='1')
min_sec: StringNumber = '0'
seconds: StringNumber = '1'
random: bool = False
cmd: str = 'sleep'

Expand All @@ -74,7 +91,7 @@ class WebServCommand(BaseCommand):
type: Literal['webserv']
cmd: str = 'HTTP-GET'
local_path: str
port: str = Field(pattern=VAR_PATTERN, default='8000')
port: StringNumber = '8000'
address: str = '0.0.0.0' # nosec


Expand All @@ -90,7 +107,7 @@ class FatherCommand(BaseCommand):
hiddenport: str = 'D431'
shell_pass: str = 'lobster'
install_path: str = '/lib/selinux.so.3'
local_path: Optional[str]
local_path: Optional[str] = None
arch: Literal['amd64'] = 'amd64'
build_command: str = 'make'

Expand All @@ -116,38 +133,33 @@ class RegExCommand(BaseCommand):


class SSHBase(BaseCommand):
@validator('session')
def session_and_background_unsupported(cls, v, values, **kwargs):
if 'background' in values and values['background']:
@field_validator('session', 'creates_session')
@classmethod
def session_and_background_unsupported(cls, v, info: ValidationInfo) -> str:
if 'background' in info.data and info.data['background']:
raise ValueError('background mode combined with session is unsupported for SSH')
return v

@validator('creates_session')
def creates_session_and_background_unsupported(cls, v, values, **kwargs):
if 'background' in values and values['background']:
raise ValueError('background mode combined with session is unsupported for SSH')
return v

hostname: Optional[str]
port: Optional[str] = Field(pattern=VAR_PATTERN, default=None)
username: Optional[str]
password: Optional[str]
passphrase: Optional[str]
key_filename: Optional[str]
creates_session: Optional[str]
session: Optional[str]
hostname: Optional[str] = None
port: StringNumber = None
username: Optional[str] = None
password: Optional[str] = None
passphrase: Optional[str] = None
key_filename: Optional[str] = None
creates_session: Optional[str] = None
session: Optional[str] = None
clear_cache: bool = False
timeout: float = 60
jmp_hostname: Optional[str]
jmp_port: Optional[str] = Field(pattern=VAR_PATTERN, default=None)
jmp_username: Optional[str]
jmp_hostname: Optional[str] = None
jmp_port: StringNumber = None
jmp_username: Optional[str] = None


class SSHCommand(SSHBase):
type: Literal['ssh']
interactive: bool = False
validate_prompt: bool = True
command_timeout: str = Field(pattern=VAR_PATTERN, default='15')
command_timeout: StringNumber = '15'
prompts: List[str] = ['$ ', '# ', '> ']


Expand All @@ -156,7 +168,7 @@ class SFTPCommand(SSHBase):
cmd: Literal['get', 'put']
remote_path: str
local_path: str
mode: Optional[str]
mode: Optional[str] = None


class MsfSessionCommand(BaseCommand):
Expand All @@ -166,7 +178,7 @@ class MsfSessionCommand(BaseCommand):
write: bool = False
read: bool = False
session: str
end_str: Optional[str]
end_str: Optional[str] = None


class MsfPayloadCommand(BaseCommand):
Expand All @@ -179,19 +191,19 @@ class MsfPayloadCommand(BaseCommand):
template: Optional[str] = None
platform: Optional[str] = None
keep_template_working: bool = False
nopsled_size: str = Field(pattern=VAR_PATTERN, default='0')
iter: str = Field(pattern=VAR_PATTERN, default='0')
nopsled_size: StringNumber = '0'
iter: StringNumber = '0'
payload_options: Dict[str, str] = {}
local_path: Optional[str]
local_path: Optional[str] = None


class MsfModuleCommand(BaseCommand):
cmd: str
type: Literal['msf-module']
target: str = Field(pattern=VAR_PATTERN, default='0')
creates_session: Optional[str]
session: Optional[str]
payload: Optional[str]
target: StringNumber = '0'
creates_session: Optional[str] = None
session: Optional[str] = None
payload: Optional[str] = None
options: Dict[str, str] = {}
payload_options: Dict[str, str] = {}

Expand Down Expand Up @@ -219,10 +231,10 @@ class HttpClientCommand(BaseCommand):
cmd: Literal['GET', 'POST', 'PUT', 'DELETE', 'PATCH', 'HEAD', 'OPTIONS'] = 'GET'
url: str
output_headers: bool = False
headers: Optional[Dict[str, str]]
cookies: Optional[Dict[str, str]]
data: Optional[Dict[str, str]]
local_path: Optional[str]
headers: Optional[Dict[str, str]] = None
cookies: Optional[Dict[str, str]] = None
data: Optional[Dict[str, str]] = None
local_path: Optional[str] = None
useragent: str = 'AttackMate'
follow: bool = False
verify: bool = False
Expand All @@ -249,16 +261,16 @@ class SliverHttpsListenerCommand(BaseCommand):
type: Literal['sliver']
cmd: Literal['start_https_listener']
host: str = '0.0.0.0' # nosec
port: str = Field(pattern=VAR_PATTERN, default='443')
port: StringNumber = '443'
domain: str = ''
website: str = ''
acme: bool = False
persistent: bool = False
enforce_otp: bool = True
randomize_jarm: bool = True
long_poll_timeout: str = Field(pattern=VAR_PATTERN, default='1')
long_poll_jitter: str = Field(pattern=VAR_PATTERN, default='2')
timeout: str = Field(pattern=VAR_PATTERN, default='60')
long_poll_timeout: StringNumber = '1'
long_poll_jitter: StringNumber = '3'
timeout: StringNumber = '60'


class SliverGenerateCommand(BaseCommand):
Expand All @@ -278,9 +290,9 @@ class SliverGenerateCommand(BaseCommand):
'SHARED_LIB',
'SHELLCODE'] = 'EXECUTABLE'
name: str
filepath: Optional[str]
filepath: Optional[str] = None
IsBeacon: bool = False
BeaconInterval: str = Field(pattern=VAR_PATTERN, default='120')
BeaconInterval: StringNumber = '120'
RunAtLoad: bool = False
Evasion: bool = False

Expand Down Expand Up @@ -328,7 +340,7 @@ class SliverSessionNETSTATCommand(SliverSessionCommand):
class SliverSessionEXECCommand(SliverSessionCommand):
cmd: Literal['execute']
exe: str
args: Optional[List[str]]
args: Optional[List[str]] = None
output: bool = True


Expand All @@ -344,7 +356,7 @@ class SliverSessionLSCommand(SliverSessionCommand):
class SliverSessionPROCDUMPCommand(SliverSessionCommand):
cmd: Literal['process_dump']
local_path: str
pid: str = Field(pattern=VAR_PATTERN)
pid: StringNumber


class SliverSessionRMCommand(SliverSessionCommand):
Expand All @@ -356,7 +368,7 @@ class SliverSessionRMCommand(SliverSessionCommand):

class SliverSessionTERMINATECommand(SliverSessionCommand):
cmd: Literal['terminate']
pid: str = Field(pattern=VAR_PATTERN)
pid: StringNumber
force: bool = False


Expand Down

0 comments on commit 018afad

Please sign in to comment.