From c132af1037889ed37d7748af0d6fe071959d77f8 Mon Sep 17 00:00:00 2001 From: Wojciech Boncela Date: Tue, 11 Nov 2025 16:22:19 +0000 Subject: [PATCH] fix: Validate target profile for 'mlia optimize' command. * A new validate_optimize_target_profile function was added. Signed-off-by: Wojciech Boncela Change-Id: Icba2c69efdf0baa65bc6f1a5941ac11e3201743c Reviewed-on: https://eu-gerrit-2.euhpc.arm.com/c/ml/ecosystem/mlia/+/1146727 Tested-by: expkit Reviewed-by: Mike Kelly IP-review: Mike Kelly (cherry picked from commit 9fccfab1c0ad47a2dc2ad1f9d7f8c8d121c7ca3d) Reviewed-on: https://eu-gerrit-2.euhpc.arm.com/c/ml/ecosystem/mlia/+/1151307 Reviewed-by: Isabella Gottardi IP-review: Isabella Gottardi --- src/mlia/cli/command_validators.py | 16 +++++++++++++++- src/mlia/cli/commands.py | 4 +++- tests/test_cli_command_validators.py | 19 +++++++++++++++++++ 3 files changed, 37 insertions(+), 2 deletions(-) diff --git a/src/mlia/cli/command_validators.py b/src/mlia/cli/command_validators.py index 04e04b1..867a645 100644 --- a/src/mlia/cli/command_validators.py +++ b/src/mlia/cli/command_validators.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2023, 2025, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """CLI command validators module.""" from __future__ import annotations @@ -102,6 +102,20 @@ def validate_check_target_profile(target_profile: str, category: set[str]) -> No sys.exit(0) +def validate_optimize_target_profile(target_profile: str) -> None: + """Validate whether the provided target profile is compatible with 'mlia optimize'. + + This function exits with code 1 if the provided target profile is + not supported. + """ + incompatible_targets_optimize: list[str] = ["tosa", "cortex-a"] + if target_profile in incompatible_targets_optimize: + logger.error( + "Optimization cannot be performed with target profile %s.", target_profile + ) + sys.exit(1) + + def normalize_string(value: str) -> str: """Given a string return the normalized version. diff --git a/src/mlia/cli/commands.py b/src/mlia/cli/commands.py index fcba302..b6471c1 100644 --- a/src/mlia/cli/commands.py +++ b/src/mlia/cli/commands.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2022-2024, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2025, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """CLI commands module. @@ -26,6 +26,7 @@ from mlia.backend.manager import get_installation_manager from mlia.cli.command_validators import validate_backend from mlia.cli.command_validators import validate_check_target_profile +from mlia.cli.command_validators import validate_optimize_target_profile from mlia.cli.options import parse_optimization_parameters from mlia.utils.console import create_section_header @@ -160,6 +161,7 @@ def optimize( # pylint: disable=too-many-locals,too-many-arguments ) ) + validate_optimize_target_profile(target_profile) validated_backend = validate_backend(target_profile, backend) get_advice( diff --git a/tests/test_cli_command_validators.py b/tests/test_cli_command_validators.py index 302cd4e..9f35229 100644 --- a/tests/test_cli_command_validators.py +++ b/tests/test_cli_command_validators.py @@ -11,6 +11,7 @@ from mlia.cli.command_validators import validate_backend from mlia.cli.command_validators import validate_check_target_profile +from mlia.cli.command_validators import validate_optimize_target_profile @pytest.mark.parametrize( @@ -192,3 +193,21 @@ def test_validate_backend_default_unavailable(monkeypatch: pytest.MonkeyPatch) - ) with pytest.raises(argparse.ArgumentError): validate_backend("cortex-a", None) + + +@pytest.mark.parametrize( + "target_profile, sys_exit", + [("ethos-u55-128", False), ("tosa", True), ("cortex-a", True)], +) +def test_validate_optimize_target_profile( + target_profile: str, + sys_exit: bool, +) -> None: + """Tests if an incompatible target is passed for optimization.""" + if sys_exit: + with pytest.raises(SystemExit) as sys_ex: + validate_optimize_target_profile(target_profile) + assert sys_ex.value.code == 1 + return + + validate_optimize_target_profile(target_profile)