Merge pull request #242 from Axonius/bugfix/user_schema
5.0.1
Jim Olsen committed May 3, 2023
2 parents e1cb134 + 1325ff0 commit d809113
Showing 33 changed files with 524 additions and 309 deletions.
101 changes: 71 additions & 30 deletions axonius_api_client/api/asset_callbacks/base.py
@@ -38,21 +38,18 @@
)


# noinspection SpellCheckingInspection
def crjoin(value):
"""Pass."""
joiner = "\n - "
return joiner + joiner.join(value)


# noinspection PyProtectedMember,PyAttributeOutsideInit
class Base:
"""Callbacks for formatting asset data.
Examples:
Create a ``client`` using :obj:`axonius_api_client.connect.Connect` and assume
``apiobj`` is either ``client.devices`` or ``client.users``
>>> apiobj = client.devices # or client.users
* :meth:`args_map` for callback generic arguments to format assets.
* :meth:`args_map_custom` for callback specific arguments to format and export data.
@@ -66,7 +63,11 @@ def args_map(cls) -> dict:
Create a ``client`` using :obj:`axonius_api_client.connect.Connect` and assume
``apiobj`` is either ``client.devices`` or ``client.users``
>>> apiobj = client.devices # or client.users
>>> import axonius_api_client as axonapi
>>> connect_args: dict = axonapi.get_env_connect()
>>> client: axonapi.Connect = axonapi.Connect(**connect_args)
>>> apiobj: axonapi.api.assets.AssetMixin = client.devices
>>> # or client.users or client.vulnerabilities
Flatten complex fields - Will take all sub-fields of complex fields and put them
on the root level with their values index correlated to each other.
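A minimal sketch of that flattening in a fetch (assuming ``field_flatten`` is the
matching callback argument and is passed straight through ``get``):
>>> assets = apiobj.get(field_flatten=True, max_rows=5)
>>> # sub-fields like 'network_interfaces.ips' now appear as root-level keys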
@@ -182,7 +183,9 @@ def args_map_base(cls) -> dict:
"field_null_value": None,
"field_null_value_complex": [],
"tags_add": [],
"tags_add_invert_selection": False,
"tags_remove": [],
"tags_remove_invert_selection": False,
"report_adapters_missing": False,
"report_software_whitelist": [],
"page_progress": 10000,
@@ -196,15 +199,15 @@ def args_map_base(cls) -> dict:
"csv_field_null": True,
}

def get_arg_value(self, arg: str) -> Union[str, list, bool, int]:
def get_arg_value(self, arg: str) -> t.Any:
"""Get an argument value.
Args:
arg: key to get from :attr:`GETARGS` with a default value from :meth:`args_map`
"""
return self.GETARGS.get(arg, self.args_map()[arg])
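For example, with a hypothetical callbacks instance ``cb`` whose ``GETARGS`` only
overrides one key, every other key falls back to the :meth:`args_map` default:
>>> cb.GETARGS = {"page_progress": 500}
>>> cb.get_arg_value("page_progress")
500
>>> cb.get_arg_value("tags_add")
[]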

def set_arg_value(self, arg: str, value: Union[str, list, bool, int]):
def set_arg_value(self, arg: str, value: t.Any):
"""Set an argument value.
Args:
@@ -267,15 +270,16 @@ def start(self, **kwargs):
self.echo(msg=f"Adding fields {missing} to field_excludes: {excludes}", debug=True)
self.set_arg_value("field_excludes", value=excludes + missing)

cbargs = crjoin(join_kv(obj=self.GETARGS))
self.LOG.debug(f"Get Extra Arguments: {cbargs}")
cb_args = crjoin(join_kv(obj=self.GETARGS))
self.LOG.debug(f"Get Extra Arguments: {cb_args}")

config = crjoin(self.args_strs)
self.echo(msg=f"Configuration: {config}")

store = crjoin(join_kv(obj=self.STORE))
self.echo(msg=f"Get Arguments: {store}")

# noinspection PyUnusedLocal
def echo_columns(self, **kwargs):
"""Echo the columns of the fields selected."""
if getattr(self, "ECHO_DONE", False):
@@ -382,6 +386,8 @@ def do_row(self, rows: Union[List[dict], dict]) -> List[dict]:
"""
debug_timing = self.get_arg_value("debug_timing")

p_start = None
cb_start = None
if debug_timing: # pragma: no cover
p_start = dt_now()

@@ -391,11 +397,11 @@ def do_row(self, rows: Union[List[dict], dict]) -> List[dict]:

rows = cb(rows=rows)
# print(f"{cb} {json_dump(rows)}")
if debug_timing: # pragma: no cover
if debug_timing and cb_start: # pragma: no cover
cb_delta = dt_now() - cb_start
self.LOG.debug(f"CALLBACK {cb} took {cb_delta} for {len(rows)} rows")

if debug_timing: # pragma: no cover
if debug_timing and p_start: # pragma: no cover
p_delta = dt_now() - p_start
self.LOG.debug(f"CALLBACKS TOOK {p_delta} for {len(rows)} rows")

@@ -531,7 +537,8 @@ def _do_join_values(self, row: dict):
if trim_len and isinstance(value, str) and len(value) >= trim_len:
field_len = len(value)
msg = trim_str.format(field_len=field_len, trim_len=trim_len)
row[field] = value = joiner.join([value[:trim_len], msg])
value = [value[:trim_len], msg]
row[field] = joiner.join(value)

def do_change_field_replace(self, rows: Union[List[dict], dict]) -> List[dict]:
"""Asset callback to replace characters.
@@ -558,6 +565,7 @@ def field_replacements(self) -> List[Tuple[str, str]]:
"""Parse the supplied list of field name replacements."""

def parse_replace(replace):
"""Parse the supplied list of field name replacements."""
if isinstance(replace, str):
replace = replace.split("=", maxsplit=1)

@@ -593,7 +601,6 @@ def _field_compress(self, key: str) -> str:
return key

splits = key.split(".")
prefix = ""

if splits[0] == "specific_data":
prefix = AGG_ADAPTER_NAME
@@ -715,6 +722,7 @@ def _do_explode_entities(self, row: dict) -> List[dict]:
"""

def explode(idx: int, adapter: str) -> dict:
"""Explode a row into a row for each asset entity."""
new_row = {"adapters": adapter}

for k, v in row.items():
@@ -796,19 +804,46 @@ def do_tagging(self):

def do_tag_add(self):
"""Add tags to assets."""
tags_add = listify(self.get_arg_value("tags_add"))
rows_add = self.TAG_ROWS_ADD
if tags_add and rows_add:
self.echo(msg=f"Adding tags {tags_add} to {len(rows_add)} assets")
self.APIOBJ.labels.add(rows=rows_add, labels=tags_add)
tags = listify(self.get_arg_value("tags_add"))
rows = self.TAG_ROWS_ADD
invert_selection = self.get_arg_value("tags_add_invert_selection")
count_tags = len(tags)
count_supplied = len(rows)
msgs = [
f" Tags supplied ({count_tags}): {tags}",
f" Asset IDs supplied ({count_supplied})",
f" Invert selection: {invert_selection}",
]
if tags:
self.echo(["Performing API call to add tags to assets", *msgs])
count_modified = self.APIOBJ.labels.add(
rows=rows, labels=tags, invert_selection=invert_selection
)
self.echo(msg=[f"API added tags to {count_modified} assets", *msgs])

def do_tag_remove(self):
"""Remove tags from assets."""
tags_remove = listify(self.get_arg_value("tags_remove"))
rows_remove = self.TAG_ROWS_REMOVE
if tags_remove and rows_remove:
self.echo(msg=f"Removing tags {tags_remove} from {len(rows_remove)} assets")
self.APIOBJ.labels.remove(rows=rows_remove, labels=tags_remove)
tags = listify(self.get_arg_value("tags_remove"))
rows = self.TAG_ROWS_REMOVE
invert_selection = self.get_arg_value("tags_remove_invert_selection")
count_tags = len(tags)
count_supplied = len(rows)
msgs = [
f" Asset IDs supplied ({count_supplied})",
f" Tags supplied ({count_tags}): {tags}",
f" Invert selection: {invert_selection}",
]
if tags:
self.echo(["Performing API call to remove tags from assets", *msgs])
count_modified = self.APIOBJ.labels.remove(
rows=rows, labels=tags, invert_selection=invert_selection
)
msgs = [
f"API finished removing tags from assets",
f" Asset IDs modified: " f"{count_modified}",
*msgs,
]
self.echo(msg=msgs)
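A usage sketch of the new invert-selection behavior (the wizard query and tag name
are illustrative only; the arguments are assumed to flow through ``apiobj.get``
into these callbacks):
>>> apiobj.get(
...     wiz_entries=[{"type": "simple", "value": "name equals example"}],
...     tags_add=["reviewed"],
...     tags_add_invert_selection=True,
... )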

def process_tags_to_add(self, rows: Union[List[dict], dict]) -> List[dict]:
"""Add assets to tracker for adding tags.
@@ -983,9 +1018,9 @@ def excluded_schemas(self) -> List[dict]:

def echo(
self,
msg: str,
msg: t.Union[str, t.List[str]],
debug: bool = False,
error: Union[bool, t.Type[Exception]] = False,
error: Union[bool, str, t.Type[Exception]] = False,
warning: bool = False,
level: str = "info",
level_debug: str = "debug",
@@ -1000,12 +1035,14 @@
error: message is an error
warning: message is a warning
level: logging level for non error/non warning messages
level_debug: logging level for debug messages
level_error: logging level for error messages
level_warning: logging level for warning messages
abort: sys.exit(1) if error is true
debug: message is a debug message
"""
do_echo = self.get_arg_value("do_echo")

msg = "\n".join(listify(msg))
if do_echo:
if error:
echo_error(msg=msg, abort=abort)
@@ -1086,6 +1123,7 @@ def final_columns(self) -> List[str]:
"""Get the columns that will be returned."""

def get_key(s):
"""Get the key for a schema."""
return self._field_replace(self._field_compress(s[key]))

if hasattr(self, "_final_columns"):
@@ -1167,7 +1205,7 @@ def schema_to_explode(self) -> dict:
def adapter_map(self) -> dict:
"""Build a map of adapters that have connections."""
if getattr(self, "_adapter_map", None):
return self._adapter_map
return getattr(self, "_adapter_map", None)

self._adapters_meta = getattr(
self, "_adapters_meta", self.APIOBJ.adapters.get(get_clients=False)
@@ -1232,7 +1270,7 @@ def __repr__(self) -> str:
STORE: dict = None
"""store dict used by get assets method to track arguments."""

CURRENT_ROWS: None
CURRENT_ROWS: t.Optional[t.List[dict]] = None
"""current rows being processed"""

GETARGS: dict = None
@@ -1248,6 +1286,7 @@ def __repr__(self) -> str:
"""tracker of custom callbacks that have been executed by :meth:`do_custom_cbs`"""


# noinspection PyAttributeOutsideInit
class ExportMixins(Base):
"""Export mixins for callbacks."""

@@ -1316,7 +1355,7 @@ def open_fd_path(self) -> IO:
debug=True,
)
elif not export_overwrite:
msg = f"Export file {str(self._file_path)!r} already exists and overwite is False!"
msg = f"Export file {str(self._file_path)!r} already exists and overwrite is False!"
self.echo(msg=msg, error=ApiError, level="error")
else:
self._file_mode: str = "Overwrote existing file"
@@ -1410,7 +1449,9 @@ def arg_export_fd_close(self) -> bool:
"field_null_value": "Null value to use for missing simple fields",
"field_null_value_complex": "Null value to use for missing complex fields",
"tags_add": "Tags to add to assets",
"tags_add_invert_selection": "Invert selection for tags to add",
"tags_remove": "Tags to remove from assets",
"tags_remove_invert_selection": "Invert selection for tags to remove",
"report_adapters_missing": "Add Missing Adapters calculation",
"report_software_whitelist": "Missing Software to calculate",
"page_progress": "Echo page progress every N assets",
5 changes: 4 additions & 1 deletion axonius_api_client/api/assets/asset_mixin.py
@@ -1153,6 +1153,7 @@ def get_generator(
fields_default: include the default fields in :attr:`fields_default`
fields_root: include all fields of an adapter that are not complex sub-fields
fields_error: throw validation errors on supplied fields
fields_parsed: previously parsed fields
max_rows: only return N rows
max_pages: only return N pages
row_start: start at row N
@@ -1172,7 +1173,6 @@ def get_generator(
wiz_entries: wizard expressions to create query from
file_date: string to use in filename templates for {DATE}
wiz_parsed: parsed output from a query wizard
fields_parsed: previously parsed fields
sort_field_parsed: previously parsed sort field
history_date_parsed: previously parsed history date
initial_count: previously fetched initial count
@@ -1823,6 +1823,9 @@ def _init(self, **kwargs):
self.labels: Labels = Labels(parent=self)
"""Work with labels (tags)."""

self.tags = self.labels
"""Alias for :attr:`labels`."""

self.saved_query: SavedQuery = SavedQuery(parent=self)
"""Work with saved queries."""

