Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/mlia/cli/command_validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def validate_check_target_profile(target_profile: str, category: set[str]) -> No
)

# Case: compatibility operation to be skipped
if try_compatibility and not do_compatibility:
if try_compatibility and not do_compatibility: # pragma: no cover, defensive code
warning_message += (
"Compatibility checks skipped as they cannot be "
f"performed with target profile {target_profile}."
Expand Down
2 changes: 1 addition & 1 deletion src/mlia/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,5 +308,5 @@ def backend_main(argv: list[str] | None = None) -> int:
return init_and_run(commands, argv)


if __name__ == "__main__":
if __name__ == "__main__": # pragma: no cover
sys.exit(main())
83 changes: 82 additions & 1 deletion tests/test_cli_command_validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import pytest

from mlia.cli.command_validators import normalize_string
from mlia.cli.command_validators import validate_backend
from mlia.cli.command_validators import validate_check_target_profile
from mlia.cli.command_validators import validate_optimize_target_profile
Expand Down Expand Up @@ -185,7 +186,9 @@ def test_validate_backend_default_available() -> None:
assert backends == ["armnn-tflite-delegate"]


def test_validate_backend_default_unavailable(monkeypatch: pytest.MonkeyPatch) -> None:
def test_validate_backend_default_unavailable(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""Test default backend validation with unavailable backend."""
monkeypatch.setattr(
"mlia.cli.command_validators.default_backends",
Expand All @@ -211,3 +214,81 @@ def test_validate_optimize_target_profile(
return

validate_optimize_target_profile(target_profile)


@pytest.mark.parametrize(
"input_string, expected_output",
[
["", ""],
["lowercase", "lowercase"],
["UPPERCASE", "uppercase"],
["VELA", "vela"],
["check-no-hyphens", "checknohyphens"],
["MixedCase-With-Hyphens", "mixedcasewithhyphens"],
["ToSa-cHecker", "tosachecker"],
["corstone-310", "corstone310"],
["armnn-tflite-delegate", "armnntflitedelegate"],
["---multiple---hyphens---", "multiplehyphens"],
],
)
def test_normalize_string(input_string: str, expected_output: str) -> None:
"""Test normalize_string function with various inputs."""
assert normalize_string(input_string) == expected_output


@pytest.mark.parametrize(
"supported_backends, target, target_profile, backends, expected",
[
(
["armnn-tflite-delegate"],
"cortex-a",
"cortex-a",
["armnn-tflite-delegate"],
["armnn-tflite-delegate"],
),
(
["Vela", "Corstone-310"],
"ethos-u55",
"ethos-u55-256",
["VELA", "corstone-310"],
["VELA", "corstone-310"],
),
],
ids=["hyphen_normalization", "case_insensitive"],
)
def test_validate_backend_normalization(
supported_backends: list[str],
target: str,
target_profile: str,
backends: list[str],
expected: list[str],
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""Test backend validation with hyphen and case normalization."""
monkeypatch.setattr(
"mlia.cli.command_validators.supported_backends",
MagicMock(return_value=supported_backends),
)
monkeypatch.setattr(
"mlia.cli.command_validators.get_target",
MagicMock(return_value=target),
)

result = validate_backend(target_profile, backends)
assert result == expected


def test_validate_backend_multiple_incompatible(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""Test validate_backend with multiple incompatible backends."""
monkeypatch.setattr(
"mlia.cli.command_validators.supported_backends",
MagicMock(return_value=["vela"]),
)
monkeypatch.setattr(
"mlia.cli.command_validators.get_target",
MagicMock(return_value="ethos-u55"),
)
with pytest.raises(argparse.ArgumentError, match="not supported"):
validate_backend("ethos-u55-256", ["tosa-checker", "corstone-320"])
60 changes: 57 additions & 3 deletions tests/test_cli_commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,11 +473,65 @@ def test_backend_command_action_install_from_path(
installation_manager_mock.install_from.assert_called_once()


def test_backend_command_action_add_download_invalid_names(
def test_backend_command_action_install_no_names_with_path(
installation_manager_mock: MagicMock,
tmp_path: Path,
) -> None:
"""Test mlia-backend command "install" with invalid backend names."""
with pytest.raises(ValueError):
"""Test backend_install raises ValueError with no names & path."""
with pytest.raises(ValueError, match="backend name"):
backend_install(path=tmp_path, names=[])
installation_manager_mock.install_from.assert_not_called()


def test_backend_command_action_install_multiple_names_with_path(
installation_manager_mock: MagicMock,
tmp_path: Path,
) -> None:
"""Test backend_install raises ValueError when multiple names & path."""
with pytest.raises(ValueError, match="backend name"):
backend_install(path=tmp_path, names=["backend1", "backend2"])
installation_manager_mock.install_from.assert_not_called()


@pytest.mark.parametrize(
"compatibility, performance, expected_category",
[
[True, True, {"compatibility", "performance"}],
[True, False, {"compatibility"}],
[False, True, {"performance"}],
[False, False, {"compatibility"}],
],
)
def test_check_category_combinations(
sample_context: ExecutionContext,
test_tflite_model: Path,
monkeypatch: pytest.MonkeyPatch,
compatibility: bool,
performance: bool,
expected_category: set[str],
) -> None:
"""Test check() with different category combinations."""
mock_performance_estimation(monkeypatch)

# Mock get_advice to capture what category is passed
get_advice_mock = MagicMock()
monkeypatch.setattr("mlia.cli.commands.get_advice", get_advice_mock)

# Mock validators
monkeypatch.setattr("mlia.cli.commands.validate_check_target_profile", MagicMock())
monkeypatch.setattr(
"mlia.cli.commands.validate_backend", MagicMock(return_value=None)
)

check(
sample_context,
target_profile="ethos-u55-256",
model=str(test_tflite_model),
compatibility=compatibility,
performance=performance,
)

# Verify get_advice was called with the expected category
get_advice_mock.assert_called_once()
call_args = get_advice_mock.call_args
assert call_args[0][2] == expected_category
70 changes: 69 additions & 1 deletion tests/test_cli_helpers.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# SPDX-FileCopyrightText: Copyright 2022-2024, Arm Limited and/or its affiliates.
# SPDX-FileCopyrightText: Copyright 2022-2025, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for the helper classes."""
from __future__ import annotations
Expand Down Expand Up @@ -77,6 +77,30 @@ def test_apply_optimizations(
resolver = CLIActionResolver(args)
assert resolver.apply_optimizations(**params) == expected_result

@staticmethod
def test_apply_optimizations_multiple_settings() -> None:
"""Test apply_optimizations with multiple optimization settings."""
args = {"model": "model.h5", "target_profile": "ethos-u55-256"}
opt_settings = [
OptimizationSettings("pruning", 0.6, None),
OptimizationSettings("clustering", 32, None),
]

resolver = CLIActionResolver(args)
result = resolver.apply_optimizations(opt_settings=opt_settings)

assert len(result) == 2
assert result[0] == "For more info: mlia optimize --help"
# Check that all optimization types are included
assert "--pruning" in result[1]
assert "--clustering" in result[1]
# Check that all targets are included
assert "--pruning-target 0.6" in result[1]
assert "--clustering-target 32" in result[1]
# Check model and target profile
assert "model.h5" in result[1]
assert "--target-profile ethos-u55-256" in result[1]

@staticmethod
def test_operator_compatibility_details() -> None:
"""Test operator compatibility details info."""
Expand Down Expand Up @@ -180,3 +204,47 @@ def test_copy_optimization_file_to_output_dir_error(tmp_path: Path) -> None:
copy_profile_file_to_output_dir(
test_target_profile_name, tmp_path, profile_to_copy="optimization_profile"
)


def test_copy_custom_profile_file_to_output_dir(tmp_path: Path) -> None:
"""Test copying a user-provided target profile file to output directory."""
# Create a custom profile file
custom_profile_dir = tmp_path / "custom_profiles"
custom_profile_dir.mkdir()
custom_profile_file = custom_profile_dir / "my_custom_profile.toml"
custom_profile_file.write_text("[target]\nname = 'custom'\n")

# Copy it to output directory
output_dir = tmp_path / "output"
output_dir.mkdir()

copy_profile_file_to_output_dir(
custom_profile_file, output_dir, profile_to_copy="target_profile"
)

# Verify the file was copied
output_file = output_dir / "my_custom_profile.toml"
assert output_file.is_file()
assert output_file.read_text() == "[target]\nname = 'custom'\n"


def test_copy_custom_optimization_profile_to_output_dir(tmp_path: Path) -> None:
"""Test copying a user-provided optimization profile file to output directory."""
# Create a custom optimization profile file
custom_profile_dir = tmp_path / "custom_profiles"
custom_profile_dir.mkdir()
custom_profile_file = custom_profile_dir / "my_optimization.toml"
custom_profile_file.write_text("[optimization]\ntype = 'custom'\n")

# Copy it to output directory
output_dir = tmp_path / "output"
output_dir.mkdir()

copy_profile_file_to_output_dir(
custom_profile_file, output_dir, profile_to_copy="optimization_profile"
)

# Verify the file was copied
output_file = output_dir / "my_optimization.toml"
assert output_file.is_file()
assert output_file.read_text() == "[optimization]\ntype = 'custom'\n"
Loading