Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion src/mlia/cli/command_validators.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
# SPDX-FileCopyrightText: Copyright 2023, 2025, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""CLI command validators module."""
from __future__ import annotations
Expand Down Expand Up @@ -102,6 +102,20 @@ def validate_check_target_profile(target_profile: str, category: set[str]) -> No
sys.exit(0)


def validate_optimize_target_profile(target_profile: str) -> None:
"""Validate whether the provided target profile is compatible with 'mlia optimize'.

This function exits with code 1 if the provided target profile is
not supported.
"""
incompatible_targets_optimize: list[str] = ["tosa", "cortex-a"]
if target_profile in incompatible_targets_optimize:
logger.error(
"Optimization cannot be performed with target profile %s.", target_profile
)
sys.exit(1)


def normalize_string(value: str) -> str:
"""Given a string return the normalized version.

Expand Down
4 changes: 3 additions & 1 deletion src/mlia/cli/commands.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# SPDX-FileCopyrightText: Copyright 2022-2024, Arm Limited and/or its affiliates.
# SPDX-FileCopyrightText: Copyright 2022-2025, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""CLI commands module.

Expand Down Expand Up @@ -26,6 +26,7 @@
from mlia.backend.manager import get_installation_manager
from mlia.cli.command_validators import validate_backend
from mlia.cli.command_validators import validate_check_target_profile
from mlia.cli.command_validators import validate_optimize_target_profile
from mlia.cli.options import parse_optimization_parameters
from mlia.utils.console import create_section_header

Expand Down Expand Up @@ -160,6 +161,7 @@ def optimize( # pylint: disable=too-many-locals,too-many-arguments
)
)

validate_optimize_target_profile(target_profile)
validated_backend = validate_backend(target_profile, backend)

get_advice(
Expand Down
19 changes: 19 additions & 0 deletions tests/test_cli_command_validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from mlia.cli.command_validators import validate_backend
from mlia.cli.command_validators import validate_check_target_profile
from mlia.cli.command_validators import validate_optimize_target_profile


@pytest.mark.parametrize(
Expand Down Expand Up @@ -192,3 +193,21 @@ def test_validate_backend_default_unavailable(monkeypatch: pytest.MonkeyPatch) -
)
with pytest.raises(argparse.ArgumentError):
validate_backend("cortex-a", None)


@pytest.mark.parametrize(
"target_profile, sys_exit",
[("ethos-u55-128", False), ("tosa", True), ("cortex-a", True)],
)
def test_validate_optimize_target_profile(
target_profile: str,
sys_exit: bool,
) -> None:
"""Tests if an incompatible target is passed for optimization."""
if sys_exit:
with pytest.raises(SystemExit) as sys_ex:
validate_optimize_target_profile(target_profile)
assert sys_ex.value.code == 1
return

validate_optimize_target_profile(target_profile)