From db360702db21259c2f6e3ded2226cc07cd22316e Mon Sep 17 00:00:00 2001 From: Ignatious Johnson Date: Wed, 19 Nov 2025 05:47:50 +0000 Subject: [PATCH 1/6] Support continuing tests with failure. overnight tests should not fail due to one bad node should continue to test other nodes to qualify eligible nodes. stop_on_errors is an optional arg supported by parallelssh library defaults to True. Change is to allow callers of Pssh instance to pass optional stop_on_errors and pass it to run_command api in exec and exec_cmd_list methods. Signed-off-by: Ignatious Johnson --- lib/parallel_ssh_lib.py | 21 +++++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/lib/parallel_ssh_lib.py b/lib/parallel_ssh_lib.py index 8b031f1c..870e4825 100755 --- a/lib/parallel_ssh_lib.py +++ b/lib/parallel_ssh_lib.py @@ -28,7 +28,7 @@ class Pssh(): mandatory args = user, password (or) 'private_key': load_private_key('my_key.pem') """ - def __init__(self, log, host_list, user=None, password=None, pkey='id_rsa', host_key_check=False ): + def __init__(self, log, host_list, user=None, password=None, pkey='id_rsa', host_key_check=False, stop_on_errors=True ): self.log = log self.host_list = host_list @@ -36,6 +36,7 @@ def __init__(self, log, host_list, user=None, password=None, pkey='id_rsa', host self.pkey = pkey self.password = password self.host_key_check = host_key_check + self.stop_on_errors = stop_on_errors if self.password is None: print(self.host_list) @@ -53,9 +54,9 @@ def exec(self, cmd, timeout=None ): cmd_output = {} print(f'cmd = {cmd}') if timeout is None: - output = self.client.run_command(cmd ) + output = self.client.run_command(cmd, stop_on_errors=self.stop_on_errors ) else: - output = self.client.run_command(cmd, read_timeout=timeout ) + output = self.client.run_command(cmd, read_timeout=timeout, stop_on_errors=self.stop_on_errors ) for item in output: print('#----------------------------------------------------------#') print(f'Host == {item.host} ==') @@ -71,6 +72,10 @@ def exec(self, cmd, timeout=None ): print(line) cmd_out_str = cmd_out_str + line.replace( '\t', ' ') cmd_out_str = cmd_out_str + '\n' + if item.exception: + exc_str = str(item.exception).replace('\t', ' ') + print(exc_str) + cmd_out_str += exc_str + '\n' cmd_output[item.host] = cmd_out_str return cmd_output @@ -85,9 +90,9 @@ def exec_cmd_list(self, cmd_list, timeout=None ): cmd_output = {} print(cmd_list) if timeout is None: - output = self.client.run_command( '%s', host_args=cmd_list ) + output = self.client.run_command( '%s', host_args=cmd_list, stop_on_errors=self.stop_on_errors ) else: - output = self.client.run_command( '%s', host_args=cmd_list, read_timeout=timeout ) + output = self.client.run_command( '%s', host_args=cmd_list, read_timeout=timeout, stop_on_errors=self.stop_on_errors ) i = 0 for item in output: print('#----------------------------------------------------------#') @@ -105,6 +110,10 @@ def exec_cmd_list(self, cmd_list, timeout=None ): print(line) cmd_out_str = cmd_out_str + line.replace( '\t', ' ') cmd_out_str = cmd_out_str + '\n' + if item.exception: + exc_str = str(item.exception).replace('\t', ' ') + print(exc_str) + cmd_out_str += exc_str + '\n' i=i+1 cmd_output[item.host] = cmd_out_str @@ -126,7 +135,7 @@ def scp_file(self, local_file, remote_file, recurse=False ): def reboot_connections(self ): print('Rebooting Connections') - self.client.run_command( 'reboot -f' ) + self.client.run_command( 'reboot -f', stop_on_errors=self.stop_on_errors ) def destroy_clients(self ): print('Destroying Current phdl connections ..') From 5542f07e8930e25e6b8f4e2ec6061313898b1052 Mon Sep 17 00:00:00 2001 From: Ignatious Johnson Date: Wed, 19 Nov 2025 06:13:49 +0000 Subject: [PATCH 2/6] Pass stop_on_errors=False to aghfc and rvs tests, so that these tests will continue to run overnight even if one of the node is unresponsive. Signed-off-by: Ignatious Johnson --- tests/health/agfhc_cvs.py | 2 +- tests/health/rvs_cvs.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/health/agfhc_cvs.py b/tests/health/agfhc_cvs.py index 78c7abb4..f9795eeb 100644 --- a/tests/health/agfhc_cvs.py +++ b/tests/health/agfhc_cvs.py @@ -136,7 +136,7 @@ def phdl(cluster_dict): nhdl_dict = {} print(cluster_dict) node_list = list(cluster_dict['node_dict'].keys()) - phdl = Pssh( log, node_list, user=cluster_dict['username'], pkey=cluster_dict['priv_key_file'] ) + phdl = Pssh( log, node_list, user=cluster_dict['username'], pkey=cluster_dict['priv_key_file'], stop_on_errors=False ) return phdl diff --git a/tests/health/rvs_cvs.py b/tests/health/rvs_cvs.py index 6306bda0..c159c761 100644 --- a/tests/health/rvs_cvs.py +++ b/tests/health/rvs_cvs.py @@ -66,7 +66,7 @@ def config_dict(config_file, cluster_dict): def phdl(cluster_dict): print(cluster_dict) node_list = list(cluster_dict['node_dict'].keys()) - phdl = Pssh( log, node_list, user=cluster_dict['username'], pkey=cluster_dict['priv_key_file'] ) + phdl = Pssh( log, node_list, user=cluster_dict['username'], pkey=cluster_dict['priv_key_file'], stop_on_errors=False ) return phdl From 58964403d97ab0474884e10efce47cf29331ca99 Mon Sep 17 00:00:00 2001 From: Ignatious Johnson Date: Wed, 5 Nov 2025 10:08:51 -0800 Subject: [PATCH 3/6] Adding sample Unittests for lib module Signed-off-by: Ignatious Johnson --- UNIT_TESTING_GUIDE.md | 92 +++++++++++++++++ lib/__init__.py | 0 lib/html_lib.py | 2 +- lib/linux_utils.py | 4 +- lib/rocm_plib.py | 4 +- lib/unittests/__init__.py | 0 lib/unittests/test_html_lib.py | 172 +++++++++++++++++++++++++++++++ lib/unittests/test_verify_lib.py | 91 ++++++++++++++++ lib/utils_lib.py | 2 +- lib/verify_lib.py | 6 +- run_all_unittests.py | 12 +++ 11 files changed, 376 insertions(+), 9 deletions(-) create mode 100644 UNIT_TESTING_GUIDE.md create mode 100644 lib/__init__.py create mode 100644 lib/unittests/__init__.py create mode 100644 lib/unittests/test_html_lib.py create mode 100644 lib/unittests/test_verify_lib.py create mode 100644 run_all_unittests.py diff --git a/UNIT_TESTING_GUIDE.md b/UNIT_TESTING_GUIDE.md new file mode 100644 index 00000000..511ca1b5 --- /dev/null +++ b/UNIT_TESTING_GUIDE.md @@ -0,0 +1,92 @@ + +# Unit Test Organization + +This project separates **unit tests** and **cluster validation tests** to maintain clarity and modularity. + +--- + +## ๐Ÿ“ Directory Structure + +``` +cvs/ +โ”œโ”€โ”€ lib/ +โ”‚ โ”œโ”€โ”€ docker_lib.py +โ”‚ โ”œโ”€โ”€ utils_lib.py +โ”‚ โ”œโ”€โ”€ verify_lib.py +โ”‚ โ””โ”€โ”€ unittests/ # Unit tests for lib modules +โ”‚ โ”œโ”€โ”€ __init__.py +โ”‚ โ”œโ”€โ”€ test_verify_lib.py +โ”‚ โ””โ”€โ”€ test_html_lib.py +โ”œโ”€โ”€ tests/ # cluster validation tests (Pytest) +โ”‚ โ”œโ”€โ”€ health +โ”‚ โ”œโ”€โ”€ ibperf +โ”‚ โ”œโ”€โ”€ rccl +โ”œโ”€โ”€ conftest.py +โ”œโ”€โ”€ pytest.ini +โ”œโ”€โ”€ README.md +``` + +--- + +## โœ… Unit Tests (using `unittest`) + +- Use Python's built-in `unittest` framework +- Each test file should be named `test_*.py` +- Each test class should inherit from `unittest.TestCase` + +### ๐Ÿ”ง How to Run Unit Tests + +From the project root (`cvs/`): + +```bash +python -m unittest discover -s lib/unittests +``` + +This will discover and run all unit tests under `lib/unittests/`. + +--- + +## ๐Ÿ› ๏ธ Run All Unit Tests from Multiple Folders + +If you have multiple unit test folders (e.g., `lib/unittests`, `utils/unittests`), create a script like: + +Run it with: + +```bash +python run_all_unittests.py +test_missing_gpu_key_raises_keyerror (test_html_lib.TestBuildHtmlMemUtilizationTable.test_missing_gpu_key_raises_keyerror) ... Build HTML mem utilization table +ok +test_multiple_nodes (test_html_lib.TestBuildHtmlMemUtilizationTable.test_multiple_nodes) ... Build HTML mem utilization table +ok +test_rocm7_style_gpu_data (test_html_lib.TestBuildHtmlMemUtilizationTable.test_rocm7_style_gpu_data) ... Build HTML mem utilization table +ok +test_single_node_valid_input (test_html_lib.TestBuildHtmlMemUtilizationTable.test_single_node_valid_input) ... Build HTML mem utilization table +ok +test_bytes_only (test_html_lib.TestNormalizeBytes.test_bytes_only) ... ok +test_gigabytes (test_html_lib.TestNormalizeBytes.test_gigabytes) ... ok +test_kilobytes_binary (test_html_lib.TestNormalizeBytes.test_kilobytes_binary) ... ok +test_kilobytes_decimal (test_html_lib.TestNormalizeBytes.test_kilobytes_decimal) ... ok +test_megabytes (test_html_lib.TestNormalizeBytes.test_megabytes) ... ok +test_negative_bytes (test_html_lib.TestNormalizeBytes.test_negative_bytes) ... ok +test_precision (test_html_lib.TestNormalizeBytes.test_precision) ... ok +test_type_error (test_html_lib.TestNormalizeBytes.test_type_error) ... ok +test_invalid_bus_speed (test_verify_lib.TestVerifyGpuPcieBusWidth.test_invalid_bus_speed) ... ok +test_valid_bus_width (test_verify_lib.TestVerifyGpuPcieBusWidth.test_valid_bus_width) ... ok +test_threshold_exceeded (test_verify_lib.TestVerifyGpuPcieErrors.test_threshold_exceeded) ... ok +test_valid_error_metrics (test_verify_lib.TestVerifyGpuPcieErrors.test_valid_error_metrics) ... ok + +---------------------------------------------------------------------- +Ran 16 tests in 0.026s + +OK +``` + +--- + +## ๐Ÿงช Tips for Organizing Tests + +- Keep unit tests close to the code they test (e.g., `lib/unittests/`) +- Use `__init__.py` in all test directories to make them importable +- Use `unittest.mock` to isolate unit tests from external dependencies + +--- diff --git a/lib/__init__.py b/lib/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/lib/html_lib.py b/lib/html_lib.py index e24ac1c7..f6078a8a 100644 --- a/lib/html_lib.py +++ b/lib/html_lib.py @@ -10,7 +10,7 @@ import re import json -from rocm_plib import * +from .rocm_plib import * def build_html_page_header(filename): diff --git a/lib/linux_utils.py b/lib/linux_utils.py index 4ab6f214..5144f167 100644 --- a/lib/linux_utils.py +++ b/lib/linux_utils.py @@ -9,9 +9,9 @@ import sys import os import json -import rocm_plib +from . import rocm_plib -from utils_lib import * +from .utils_lib import * diff --git a/lib/rocm_plib.py b/lib/rocm_plib.py index ee57ae7b..a73aee1c 100644 --- a/lib/rocm_plib.py +++ b/lib/rocm_plib.py @@ -8,9 +8,9 @@ import re import os import sys -import parallel_ssh_lib +from . import parallel_ssh_lib -from utils_lib import * +from .utils_lib import * def get_rocm_smi_dict( phdl ): diff --git a/lib/unittests/__init__.py b/lib/unittests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/lib/unittests/test_html_lib.py b/lib/unittests/test_html_lib.py new file mode 100644 index 00000000..e259966a --- /dev/null +++ b/lib/unittests/test_html_lib.py @@ -0,0 +1,172 @@ +import unittest +import tempfile +import os + +# Import the module under test +import lib.html_lib as html_lib + +class TestNormalizeBytes(unittest.TestCase): + + def test_bytes_only(self): + self.assertEqual(html_lib.normalize_bytes(932), "932 B") + + def test_kilobytes_binary(self): + self.assertEqual(html_lib.normalize_bytes(2048), "2 KB") + + def test_kilobytes_decimal(self): + self.assertEqual(html_lib.normalize_bytes(2000, si=True), "2 kB") + + def test_megabytes(self): + self.assertEqual(html_lib.normalize_bytes(5 * 1024 * 1024), "5 MB") + + def test_gigabytes(self): + self.assertEqual(html_lib.normalize_bytes(3 * 1024**3), "3 GB") + + def test_negative_bytes(self): + self.assertEqual(html_lib.normalize_bytes(-1024), "-1 KB") + + def test_precision(self): + self.assertEqual(html_lib.normalize_bytes(1536, precision=1), "1.5 KB") + + def test_type_error(self): + with self.assertRaises(TypeError): + html_lib.normalize_bytes("not a number") + + +class TestBuildHtmlMemUtilizationTable(unittest.TestCase): + + def setUp(self): + self.tmp_file = tempfile.NamedTemporaryFile(delete=False, mode='w+', encoding='utf-8') + self.filename = self.tmp_file.name + + def tearDown(self): + self.tmp_file.close() + os.remove(self.filename) + + def test_single_node_valid_input(self): + use_dict = { + "node1": { + **{f"card{i}": { + "GPU Memory Allocated (VRAM%)": f"{i*10}%", + "GPU Memory Read/Write Activity (%)": f"{i*5}%", + "Memory Activity": f"{i*3}%", + "Avg. Memory Bandwidth": f"{i*2} GB/s" + } for i in range(8)} + } + } + + amd_dict = { + "node1": [ + { + "mem_usage": { + "total_vram": {"value": "16384"}, + "used_vram": {"value": "8192"}, + "free_vram": {"value": "8192"} + } + } for _ in range(8) + ] + } + + html_lib.build_html_mem_utilization_table(self.filename, use_dict, amd_dict) + with open(self.filename, 'r', encoding='utf-8') as f: + content = f.read() + self.assertIn("GPU Memory Utilization", content) + self.assertIn("G0 Tot VRAM MB", content) + self.assertIn("node1", content) + self.assertIn("8192", content) + self.assertIn("10%", content) + + def test_multiple_nodes(self): + use_dict = { + f"node{i}": { + **{f"card{j}": { + "GPU Memory Allocated (VRAM%)": f"{j*10}%", + "GPU Memory Read/Write Activity (%)": f"{j*5}%", + "Memory Activity": f"{j*3}%", + "Avg. Memory Bandwidth": f"{j*2} GB/s" + } for j in range(8)} + } for i in range(2) + } + + amd_dict = { + f"node{i}": [ + { + "mem_usage": { + "total_vram": {"value": "16384"}, + "used_vram": {"value": "8192"}, + "free_vram": {"value": "8192"} + } + } for _ in range(8) + ] for i in range(2) + } + + html_lib.build_html_mem_utilization_table(self.filename, use_dict, amd_dict) + with open(self.filename, 'r', encoding='utf-8') as f: + content = f.read() + self.assertIn("node0", content) + self.assertIn("node1", content) + + def test_rocm7_style_gpu_data(self): + use_dict = { + "node1": { + **{f"card{i}": { + "GPU Memory Allocated (VRAM%)": f"{i*10}%", + "GPU Memory Read/Write Activity (%)": f"{i*5}%", + "Memory Activity": f"{i*3}%", + "Avg. Memory Bandwidth": f"{i*2} GB/s" + } for i in range(8)} + } + } + + amd_dict = { + "node1": { + "gpu_data": [ + { + "mem_usage": { + "total_vram": {"value": "16384"}, + "used_vram": {"value": "8192"}, + "free_vram": {"value": "8192"} + } + } for _ in range(8) + ] + } + } + + html_lib.build_html_mem_utilization_table(self.filename, use_dict, amd_dict) + with open(self.filename, 'r', encoding='utf-8') as f: + content = f.read() + self.assertIn("GPU Memory Utilization", content) + self.assertIn("G0 Tot VRAM MB", content) + self.assertIn("node1", content) + + def test_missing_gpu_key_raises_keyerror(self): + use_dict = { + "node1": { + "card0": { + "GPU Memory Allocated (VRAM%)": "10%", + "GPU Memory Read/Write Activity (%)": "20%", + "Memory Activity": "30%", + "Avg. Memory Bandwidth": "40 GB/s" + } + # Missing card1 to card7 + } + } + + amd_dict = { + "node1": [ + { + "mem_usage": { + "total_vram": {"value": "16384"}, + "used_vram": {"value": "8192"}, + "free_vram": {"value": "8192"} + } + } for _ in range(8) + ] + } + + with self.assertRaises(KeyError): + html_lib.build_html_mem_utilization_table(self.filename, use_dict, amd_dict) + + +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/lib/unittests/test_verify_lib.py b/lib/unittests/test_verify_lib.py new file mode 100644 index 00000000..c7d1722f --- /dev/null +++ b/lib/unittests/test_verify_lib.py @@ -0,0 +1,91 @@ +import unittest +from unittest.mock import MagicMock, patch + + +import lib.verify_lib as verify_lib + +class TestVerifyGpuPcieBusWidth(unittest.TestCase): + + @patch('lib.verify_lib.get_gpu_pcie_bus_dict') + @patch('lib.verify_lib.fail_test') + def test_valid_bus_width(self, mock_fail_test, mock_get_bus_dict): + mock_get_bus_dict.return_value = { + 'node1': { + 'card0': {'PCI Bus': '0000:01:00.0'}, + 'card1': {'PCI Bus': '0000:02:00.0'} + }, + 'node2': { + 'card0': {'PCI Bus': '0000:03:00.0'}, + 'card1': {'PCI Bus': '0000:04:00.0'} + } + } + + phdl = MagicMock() + phdl.exec_cmd_list.return_value = { + 'node1': 'LnkSta: Speed 32GT/s, Width x16', + 'node2': 'LnkSta: Speed 32GT/s, Width x16' + } + + result = verify_lib.verify_gpu_pcie_bus_width(phdl, expected_cards=2) + self.assertEqual(result, {'node1': [], 'node2': []}) + mock_fail_test.assert_not_called() + + @patch('lib.verify_lib.get_gpu_pcie_bus_dict') + @patch('lib.verify_lib.fail_test') + def test_invalid_bus_speed(self, mock_fail_test, mock_get_bus_dict): + mock_get_bus_dict.return_value = { + 'node1': { + 'card0': {'PCI Bus': '0000:01:00.0'} + } + } + + phdl = MagicMock() + phdl.exec_cmd_list.return_value = { + 'node1': 'LnkSta: Speed 16GT/s, Width x16' + } + + verify_lib.verify_gpu_pcie_bus_width(phdl, expected_cards=1) + mock_fail_test.assert_called() + + +class TestVerifyGpuPcieErrors(unittest.TestCase): + + @patch('lib.verify_lib.get_gpu_metrics_dict') + @patch('lib.verify_lib.fail_test') + def test_valid_error_metrics(self, mock_fail_test, mock_get_metrics): + mock_get_metrics.return_value = { + 'node1': { + 'card0': { + 'pcie_l0_to_recov_count_acc (Count)': '10', + 'pcie_nak_sent_count_acc (Count)': '20', + 'pcie_nak_rcvd_count_acc (Count)': '30' + } + } + } + + phdl = MagicMock() + result = verify_lib.verify_gpu_pcie_errors(phdl) + self.assertEqual(result, {'node1': []}) + mock_fail_test.assert_not_called() + + @patch('lib.verify_lib.get_gpu_metrics_dict') + @patch('lib.verify_lib.fail_test') + def test_threshold_exceeded(self, mock_fail_test, mock_get_metrics): + mock_get_metrics.return_value = { + 'node1': { + 'card0': { + 'pcie_l0_to_recov_count_acc (Count)': '101', + 'pcie_nak_sent_count_acc (Count)': '150', + 'pcie_nak_rcvd_count_acc (Count)': '200' + } + } + } + + phdl = MagicMock() + result = verify_lib.verify_gpu_pcie_errors(phdl) + self.assertEqual(len(result['node1']), 3) + mock_fail_test.assert_called() + + +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/lib/utils_lib.py b/lib/utils_lib.py index bc8f0387..23951dc1 100644 --- a/lib/utils_lib.py +++ b/lib/utils_lib.py @@ -11,7 +11,7 @@ import json import pytest -import globals +from . import globals log = globals.log diff --git a/lib/verify_lib.py b/lib/verify_lib.py index 86eb4b74..ccaa806b 100644 --- a/lib/verify_lib.py +++ b/lib/verify_lib.py @@ -9,9 +9,9 @@ import re import sys -from utils_lib import * -from rocm_plib import * -import linux_utils +from .utils_lib import * +from .rocm_plib import * +from . import linux_utils err_patterns_dict = { diff --git a/run_all_unittests.py b/run_all_unittests.py new file mode 100644 index 00000000..77519ffb --- /dev/null +++ b/run_all_unittests.py @@ -0,0 +1,12 @@ +# run_all_unittests.py +import unittest + +loader = unittest.TestLoader() +suite = unittest.TestSuite() + +# Add all unit test directories +for test_dir in ['lib/unittests']: + suite.addTests(loader.discover(start_dir=test_dir)) + +runner = unittest.TextTestRunner(verbosity=2) +runner.run(suite) \ No newline at end of file From a286a5f1079cf90dc87f41f9ef51262b389d2b25 Mon Sep 17 00:00:00 2001 From: Ignatious Johnson Date: Wed, 19 Nov 2025 20:07:36 +0000 Subject: [PATCH 4/6] Coming up with makefile to execute UT in virtual env conveniently. Signed-off-by: Ignatious Johnson --- Makefile | 43 ++++++++++++++++++ ...max_benchmark.cpython-312-pytest-8.4.1.pyc | Bin 11337 -> 0 bytes 2 files changed, 43 insertions(+) create mode 100644 Makefile delete mode 100644 tests/inference/inferencemax/__pycache__/inference_max_benchmark.cpython-312-pytest-8.4.1.pyc diff --git a/Makefile b/Makefile new file mode 100644 index 00000000..a1a3cf37 --- /dev/null +++ b/Makefile @@ -0,0 +1,43 @@ +VENV_DIR = test_venv +PYTHON = python +PIP = $(VENV_DIR)/bin/pip + +.PHONY: help venv build install test clean all clean_venv clean_build + +all: build venv install test + +help: + @echo "Available targets:" + @echo " venv - Create virtual environment" + @echo " install - Install requirements" + @echo " test - Test cvs list and cvs generate commands" + @echo " all - Run venv, install, and test" + @echo " clean - Remove virtual environment and Python cache files" + +venv: clean_venv + @echo "Creating virtual environment..." + $(PYTHON) -m venv $(VENV_DIR) + +install: venv build + @echo "Installing from built distribution..." + $(PIP) install -r requirements.txt + +test: install + @echo "Unit Testing cvs..." + $(VENV_DIR)/bin/python run_all_unittests.py + +clean_venv: + @echo "Removing virtual environment..." + @if [ -n "$$VIRTUAL_ENV" ] && [ "$$VIRTUAL_ENV" = "$$(pwd)/$(VENV_DIR)" ]; then \ + echo "ERROR: You are currently in the venv. Please run 'deactivate' first."; \ + exit 1; \ + fi + rm -rf $(VENV_DIR) + +clean_pycache: + @echo "Removing Python cache files..." + find . -type d -name "__pycache__" -exec rm -rf {} + 2>/dev/null || true + find . -name "*.pyc" -delete 2>/dev/null || true + find . -name "*.pyo" -delete 2>/dev/null || true + +clean: clean_venv clean_pycache diff --git a/tests/inference/inferencemax/__pycache__/inference_max_benchmark.cpython-312-pytest-8.4.1.pyc b/tests/inference/inferencemax/__pycache__/inference_max_benchmark.cpython-312-pytest-8.4.1.pyc deleted file mode 100644 index bea64d7a4ae713eb8fb2f8829364c8450e90ad7c..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 11337 zcmd5?TWl0pny%`uzS`{z#y7&{5Ce7tZ9^cJfj}@2ZW|mlkk!N)rnlTxwySA(^;A{c zwA0KO4HC>Q(vZ#W!YUfASK0?6G7{Q-$*yL#lX=~jZrKr5#nMJfUPjsnPhb^^dD-tj zb?I)~Nzl&2PJvIII`{LR|Ns8W>Hkb5;sTxv`CjdBcL~B5^x-{{n|b_i0YSJcD8i_q zh)N(Qjt2NUI3;BQqA(R24U0}qWHiESqN7p%j*Z6nJ3bn3_--0);@{1q&HSAhP4G9$ z)mV13w}9F^yRDY@a_A9!fR)k1=a+JAzruP@F6G+YTp+h$bORr|W3+?6H;!&pH)eM` z{p!=v&VZmsvU}^xoV*hHDEvTl^N}0E1GmX&_cQf~*7fLlrXI1n9=*@hBVN~I(=+vG zs_XIGGxcb$>#_NndL-(4YbWSfMzU(=xJC zpI4Zot9&e`O%-ybhHBkla=yf7WDT8}KF%_Feq2*9iYDhYt27wDG^w(ZDjS#~lPl5^ zv^=v=IIbJ&gs$Z$>Xu36yfKYrsLm3Qr}7!yxZ-R!qbn*+WMSzTZ&Ee*9xoVr0Xrzs zGQO2yzaxi)1aHeES(if zmTFpqao&H}m@p6V9CIr;i6;)R<3%lJr9e%mJsa1qT1AjTy;{&sqH9Kz%VcbBVxLJ2KhL4bNwlNDB7_J)pE`OAI%Oq9V(l2((LKE zp-Fv89hz42S4>084`rs!AsW^k((>bAKTK82P03elA8CB0$IxJMN(TQzsv{VYC!{9K zP9$@}%2uJd=TX!4YSZ@l!G)&%6>0xtR0w}G8~*B>*pS3@;!m9mEdvtna563T+qki| zpFf%zYeUs|(OXH{VT(qdJkPPih@DuA=Y z&n|B^IWzF1F{0oBgNz=B4~=GI?TxDz<1}qU)qZNZmZMS&P5Ua+J|ltZ2964&fCJk@ zWbO)<4{K%Nft!C(xF%ZEhCf-)%<>BGF8jcDjZ7{4s~I{LF*|>L8t_KE7O-er{K+jHS0>I*G#VJy(No3o{}>cEscEH*^`u2{M!70p@t{<Swgml`wxE#%~kI;rO{vS|k~ zsx1Kz^<4;Q6POHJJssr%~$ zn()p~O=7a!UBk`6Pntu+oso=dhoe*wd=dC-e0-w_1inKAe3rcyRN#9kLQMpI8dOBu z;x(xz3}p#I9Z^35KTkJw>{+DXaSfN_itN;V>~$=oAdmnK!@B2&8lOp+mk zV8J1A4|Qqt^;McqB5m0%jU{FX&2To}>1?I`lw8K2dpcjn5|x(c7h-)x-=X=^@}{l- zko^1P{Mm&~gOx}(ozBqbojd0ztDQr0BTvFY*GmuczsW6hUYr|QYU-^>y>vQ5y!TTx zhJkN)9SQz>XgK^0XY*?``Tsea?|>@F12sU^dZ)O*c8YDj5s)s(hCF2^!>i8IAYanx zqNE$in!_`;pw#YMJq|y})yKA-Mxi8&KT}3BCp^=!ZQff+y;j|P42Oh++j69G5AEg;3peG!_P(5*EAad*4k{b`b5go2f{CXV4xm~bWQq+ME#uHj5y=Y zo{Y->g@wmS?tEQ#rZ_P{D&iDCmdP17d_a5sikfG>BIA}h*~@k2x$9neMuOv!O{}Bf z8U>#X5FkLuU1~JCJmpEXP&P8TickYkRGy{JSc>&ej;HzBINbL)YD&IJfEf8}p3XaS z%Gt~MN|XB|m!R3aW~AGwD||wwar^E7$JI%BTGI`0Q7JfHGVCx?k99{ES#^fHTmBL8$z9tn(Lg0zG@gp_k=gdw z4Fd~@%MPcQT&@1uwtl$Oj?X==7xT*CY{zh(snm|?nM~1ug>Z=bq6L))POh-cZc*RN(&?rs8GII#DDckRGvqg0^mQaw%prnL9^BbJWO!y&ao0^ z*CA}EtI zU}GZ2#O0YKm~@(_cV0_1q~YZeQq!ukWyOtVOZ)#yJil9Zda! ztc9}+B*G5Zj2COi$SKM9_~W9MxdJl5tmkTD0GoDsZE469pJ6!k9Yd{$cNuL}h=@li zy|~pR5U)P8vzKB!G9!-#?CTvNVwNsL!*e!N&K@3X^N@_=R(nw*7kOj;I_;>(8Fr%K z+>-%YqAFV|;G91U8A%8p93b$SEfE*n677c@Xt2deEbUl+QpxcnK$iHLw+@@>0?JJi zSt6+8=N!hDLCN3Y@9|kAfYIGTd&k__Whr)R@1oSb+_v$~59jw(I|mYd!lvQQR_yhzKttaVcng zj|vVUft#}))s9j zK(N)Zc8IUJ0aG8r6;Z9@=XPqCS9HK86z>_AWbQ*ixQ;GF6K%47?3uPMzWIBZE+Q32}6WtBV}2F}NHP z)6I4ndS)tRPH8M9vodZUV3DQ{?1M?g0VAkrh8>&K;gLJgp}I!#12GZz%Z^SIifLjLlYAChng=qc@oi92#h({U zVL~W=YuAnQfSKEUx5oc!&!W_|)YkQ=jaA#&e0#O6Z|>|;NAKOk?;WmedwHSbmAN;T zS~lJLUbSWC-09`Sp8Ka45(nqba06?t^byK&=~Iw_Yli?JZ3AIGU*fX$(uz8I$_Vb& zoU^*NG~TMG-q6AwSy1l5gpTy^tCBX|FBOrSsb zeUZzF(|WE*Ds4=KkU{z=n7@uOa=5H|B17t- zLpC-rW>ADTL#JFUd_y>t8xy$j!GTVZ_0iCW-&z5YFnby&NzpQw9nt`%v^)7>1d#cq z6;|LL6?5*7)cmw&==muy7r#$M;Em%EzkUKj(H&|2H47HWlnyhsm^W0XL2aTTcaj=4 zPIwktUWRKo!wo}ci|%Ctj>)mB>GG4TZCFThiO$gtA>)>dQ;YD%SMi`QqHjvefcO+EJBu%#YvCFG@%G_r9vM@8Pabr9*2WhoUYVa()Ylfx^e9 z^Q;#ys$I;A!h~S?@RFtx_yfptRCb^izsdTM9j}!GS!l~WGQ1^%IVD7?{}pcditrXu zCH7JLfqUlxLxk)X{$yLcTZ8pTEZMOdscqk*|jyES4ac9cTde^Jx<*POLF%4oryVd93v^FJA zsCHYe06E+k+D!`osf447n;ZdjlxY+a?c{<7UpcKk1Yf`dh?kNfZ~80To;-OdPWWmuv$;};A$6<17TjTpx%pe!ay+y#4LNRB-m}wMsI|@=bE)%~=Wd)Pv!t`n1BEXM6n0gmUH8Ke zTbANImEPUe_#U3wQ;qMvf4Lez{7qk@)%Yt*9X<0q|EDPw+Wn+O2*qxmyMFGLaRXQr ztHh5jO2bR7oxgbZ=kMM-ywKVYkIIfzV0umbBd3{)^JjOZR?M*}V7u z)sLkQfAVYNKdgVXK5cz{xw*5_bNJ(vAHP)TI{s<%iIos_LAX5B{i)RBz?C0N#%_JM zigLID2%*q<(yp)zg5`&^v-N%5cCeOUSCK%LL*;-H_$cTB%C&GgoQ-+o)!sD$Lw!@N z+{-o8`;g?-3jXiL;1Pwht+jG5*U;YwDzCP_zvaVKtxpLK;K*$AM8QhoK0LMmg?(eE z@Z+0K1>UxyjTK?)r+qd`#*H6fVB>8hcDrB9e_8=+%WkbliF5jxy>$#fjZxD4w@^lH zoOD(nOJ}AOdlR&iVbxW*Em6O=E|b%s^r^4CrGUUl8umq9S446OuqyCdD!T_)sQ`(( zN=KSZL&nOJ5$?x2g2nB0u1L4j&M3NShqbA+K9)7f!F0?-CvsqK=9(3>N`qEu&P@>pYm!Y+o9H!Ys56FUsdmgSvLjwu9G&^LWy20TS%a=1 z?VyQBOh7k{cksn8L+uFMN(y-)a)YZ7&@oRl*lk`EG>xzgndNp2zj@1ArLS8AQM@7iEb)a9`%H*`=B7=b3HaZF^m8G2GkQIGv-NuGo%TDY7lkcf z3R{YgqEdn;^vj>SLT8%Q6bPa*L8F2^{sa%7sYMA544CcJPry0wnDj; z=hs&6_pA4>46Ls15PRHe1kURhn; oBW`f3kz0A`+pAw)UES#qj@(xVJH^9az0fK4K0eqa_HuCkUzRg#ga7~l From b4028790d83b53867b5d931c4158dfd033974e6f Mon Sep 17 00:00:00 2001 From: Ignatious Johnson Date: Wed, 19 Nov 2025 21:21:02 +0000 Subject: [PATCH 5/6] Unittest for parallel ssh lib covers exec and exec_cmd_list methods Signed-off-by: Ignatious Johnson --- lib/unittests/test_parallel_ssh_lib.py | 146 +++++++++++++++++++++++++ 1 file changed, 146 insertions(+) create mode 100644 lib/unittests/test_parallel_ssh_lib.py diff --git a/lib/unittests/test_parallel_ssh_lib.py b/lib/unittests/test_parallel_ssh_lib.py new file mode 100644 index 00000000..fb094c5f --- /dev/null +++ b/lib/unittests/test_parallel_ssh_lib.py @@ -0,0 +1,146 @@ +import unittest +from unittest.mock import patch, MagicMock +from lib.parallel_ssh_lib import Pssh + +class TestPsshExec(unittest.TestCase): + + @patch('lib.parallel_ssh_lib.ParallelSSHClient') + def setUp(self, mock_pssh_client): + self.mock_client = MagicMock() + mock_pssh_client.return_value = self.mock_client + self.host_list = ['host1', 'host2'] + self.pssh = Pssh('log', self.host_list, user='user', password='pass') + + def test_exec_successful(self): + # Test: Execute command successfully on all hosts + mock_output1 = MagicMock() + mock_output1.host = 'host1' + mock_output1.stdout = ['output1 line1', 'output1 line2'] + mock_output1.stderr = [] + mock_output1.exception = None + + mock_output2 = MagicMock() + mock_output2.host = 'host2' + mock_output2.stdout = ['output2 line1'] + mock_output2.stderr = [] + mock_output2.exception = None + + self.mock_client.run_command.return_value = [mock_output1, mock_output2] + + result = self.pssh.exec('echo hello') + + self.mock_client.run_command.assert_called_once_with('echo hello', stop_on_errors=True) + self.assertIn('host1', result) + self.assertIn('host2', result) + self.assertIn('output1 line1', result['host1']) + self.assertIn('output2 line1', result['host2']) + + def test_exec_with_exception_stop_on_errors_true(self): + # Test: Handle exceptions with stop_on_errors=True (default) + # Exception should be raised, and no result returned (no partial results) + from pssh.exceptions import Timeout + self.mock_client.run_command.side_effect = Timeout('Connection failed') + + # With stop_on_errors=True, run_command raises on exception, no result returned + with self.assertRaises(Timeout) as cm: + result = self.pssh.exec('echo hello') # This should raise, so result is not assigned + + self.assertIn('Connection failed', str(cm.exception)) + # Since exception was raised, result was not returned + self.assertNotIn('result', locals()) + + def test_exec_with_exception_stop_on_errors_false(self): + # Test Case 2.2: Execute command with timeout and stop_on_errors=False + # Exception should not be raised instead populated in output for failed hosts, success for others + self.pssh.stop_on_errors = False + from pssh.exceptions import Timeout + mock_output1 = MagicMock() + mock_output1.host = 'host1' + mock_output1.stdout = ['success output'] + mock_output1.stderr = [] + mock_output1.exception = None + + mock_output2 = MagicMock() + mock_output2.host = 'host2' + mock_output2.stdout = [] + mock_output2.stderr = [] + mock_output2.exception = Timeout('Command timed out') + + self.mock_client.run_command.return_value = [mock_output1, mock_output2] + + result = self.pssh.exec('echo hello', timeout=10) + + self.mock_client.run_command.assert_called_once_with('echo hello', read_timeout=10, stop_on_errors=False) + self.assertIn('host1', result) + self.assertIn('host2', result) + self.assertIn('success output', result['host1']) + self.assertIn('Command timed out', result['host2']) + + def test_exec_cmd_list_successful(self): + # Test: Execute different commands on different hosts successfully + cmd_list = ['echo host1', 'echo host2'] + mock_output1 = MagicMock() + mock_output1.host = 'host1' + mock_output1.stdout = ['host1'] + mock_output1.stderr = [] + mock_output1.exception = None + + mock_output2 = MagicMock() + mock_output2.host = 'host2' + mock_output2.stdout = ['host2'] + mock_output2.stderr = [] + mock_output2.exception = None + + self.mock_client.run_command.return_value = [mock_output1, mock_output2] + + result = self.pssh.exec_cmd_list(cmd_list) + + self.mock_client.run_command.assert_called_once_with('%s', host_args=cmd_list, stop_on_errors=True) + self.assertIn('host1', result) + self.assertIn('host2', result) + self.assertIn('host1', result['host1']) + self.assertIn('host2', result['host2']) + + def test_exec_cmd_list_with_exception_stop_on_errors_false(self): + # Test: Handle exceptions with stop_on_errors=False for exec_cmd_list + # Exception should not be raised instead populated in output for failed hosts, success for others + self.pssh.stop_on_errors = False + cmd_list = ['echo success', 'echo fail'] + from pssh.exceptions import Timeout + mock_output1 = MagicMock() + mock_output1.host = 'host1' + mock_output1.stdout = ['success'] + mock_output1.stderr = [] + mock_output1.exception = None + + mock_output2 = MagicMock() + mock_output2.host = 'host2' + mock_output2.stdout = [] + mock_output2.stderr = [] + mock_output2.exception = Timeout('Command timed out') + + self.mock_client.run_command.return_value = [mock_output1, mock_output2] + + result = self.pssh.exec_cmd_list(cmd_list, timeout=10) + + self.mock_client.run_command.assert_called_once_with('%s', host_args=cmd_list, read_timeout=10, stop_on_errors=False) + self.assertIn('host1', result) + self.assertIn('host2', result) + self.assertIn('success', result['host1']) + self.assertIn('Command timed out', result['host2']) + + def test_exec_cmd_list_with_exception_stop_on_errors_true(self): + # Test: Handle exceptions with stop_on_errors=True for exec_cmd_list + # Exception should be raised, and no result returned (no partial results) + cmd_list = ['echo test'] + from pssh.exceptions import Timeout + self.mock_client.run_command.side_effect = Timeout('Command timed out') + + with self.assertRaises(Timeout) as cm: + result = self.pssh.exec_cmd_list(cmd_list, timeout=5) + + self.assertIn('Command timed out', str(cm.exception)) + self.assertNotIn('result', locals()) + +if __name__ == '__main__': + unittest.main() \ No newline at end of file From 1192c906292ba36bebe44b7ee60a8179c2c3b3a9 Mon Sep 17 00:00:00 2001 From: Ignatious Johnson Date: Thu, 20 Nov 2025 00:07:25 +0000 Subject: [PATCH 6/6] Support to prune unreachable nodes in case of pssh.exceptions.Timeout exception and the node is unreachable. Unreachability is ensured by creating a ssh session to the specific set of nodes which raised Timeout. Added UT to cover these cases Signed-off-by: Ignatious Johnson --- Makefile | 9 +- lib/html_lib.py | 2 +- lib/linux_utils.py | 4 +- lib/parallel_ssh_lib.py | 74 ++++- lib/rocm_plib.py | 4 +- lib/unittests/test_html_lib.py | 342 +++++++++++------------ lib/unittests/test_parallel_ssh_lib.py | 365 ++++++++++++++++++++++++- lib/unittests/test_verify_lib.py | 180 ++++++------ lib/utils_lib.py | 2 +- lib/verify_lib.py | 6 +- run_all_unittests.py | 35 ++- 11 files changed, 725 insertions(+), 298 deletions(-) diff --git a/Makefile b/Makefile index a1a3cf37..f89e2a7e 100644 --- a/Makefile +++ b/Makefile @@ -1,10 +1,9 @@ VENV_DIR = test_venv -PYTHON = python +PYTHON := $(shell command -v python3 || command -v python) PIP = $(VENV_DIR)/bin/pip -.PHONY: help venv build install test clean all clean_venv clean_build - -all: build venv install test +.PHONY: help venv install test clean all clean_venv +all: venv install test help: @echo "Available targets:" @@ -18,7 +17,7 @@ venv: clean_venv @echo "Creating virtual environment..." $(PYTHON) -m venv $(VENV_DIR) -install: venv build +install: venv @echo "Installing from built distribution..." $(PIP) install -r requirements.txt diff --git a/lib/html_lib.py b/lib/html_lib.py index f6078a8a..e24ac1c7 100644 --- a/lib/html_lib.py +++ b/lib/html_lib.py @@ -10,7 +10,7 @@ import re import json -from .rocm_plib import * +from rocm_plib import * def build_html_page_header(filename): diff --git a/lib/linux_utils.py b/lib/linux_utils.py index 5144f167..4ab6f214 100644 --- a/lib/linux_utils.py +++ b/lib/linux_utils.py @@ -9,9 +9,9 @@ import sys import os import json -from . import rocm_plib +import rocm_plib -from .utils_lib import * +from utils_lib import * diff --git a/lib/parallel_ssh_lib.py b/lib/parallel_ssh_lib.py index 870e4825..5228ca2e 100755 --- a/lib/parallel_ssh_lib.py +++ b/lib/parallel_ssh_lib.py @@ -7,12 +7,14 @@ from __future__ import print_function from pssh.clients import ParallelSSHClient +from pssh.exceptions import Timeout import sys import os import re import ast import json +import time # Following used only for scp of file import paramiko @@ -30,21 +32,69 @@ class Pssh(): def __init__(self, log, host_list, user=None, password=None, pkey='id_rsa', host_key_check=False, stop_on_errors=True ): - self.log = log - self.host_list = host_list - self.user = user - self.pkey = pkey - self.password = password + self.log = log + self.host_list = host_list + self.reachable_hosts = host_list + self.user = user + self.pkey = pkey + self.password = password self.host_key_check = host_key_check self.stop_on_errors = stop_on_errors + self.unreachable_hosts = [] if self.password is None: - print(self.host_list) + print(self.reachable_hosts) print(self.user) print(self.pkey) - self.client = ParallelSSHClient( self.host_list, user=self.user, pkey=self.pkey, keepalive_seconds=30 ) + self.client = ParallelSSHClient( self.reachable_hosts, user=self.user, pkey=self.pkey, keepalive_seconds=30 ) else: - self.client = ParallelSSHClient( self.host_list, user=self.user, password=self.password, keepalive_seconds=30 ) + self.client = ParallelSSHClient( self.reachable_hosts, user=self.user, password=self.password, keepalive_seconds=30 ) + + + def check_connectivity(self, host): + """ + Check if a host is reachable via SSH by attempting a simple command. + """ + try: + temp_client = ParallelSSHClient([host], user=self.user, pkey=self.pkey if self.password is None else None, password=self.password, keepalive_seconds=30) + temp_client.run_command('echo 1', stop_on_errors=True, read_timeout=5) + return True + except Exception: + return False + + def prune_unreachable_hosts(self, output): + """ + Prune unreachable hosts from self.reachable_hosts if they have Timeout exceptions and also fails connectivity check. + + Targeted pruning: Only Timeout exceptions trigger pruning to avoid removing hosts for transient failures + like authentication errors or SSH protocol issues, which may succeed on next try. Timeouts are indicative + of potential unreachability, so we perform an additional connectivity check before pruning. This ensures + that hosts are not permanently removed from the list for recoverable errors. + """ + initial_unreachable_len = len(self.unreachable_hosts) + # Only prune on Timeout exceptions to avoid removing hosts for transient issues like auth failures. + # Timeouts indicate potential unreachability, so we check connectivity and prune if confirmed. + failed_hosts = [item.host for item in output if item.exception and isinstance(item.exception, Timeout)] + for host in failed_hosts: + if not self.check_connectivity(host): + print(f"Host {host} is unreachable, pruning from reachable hosts list.") + self.unreachable_hosts.append(host) + self.reachable_hosts.remove(host) + if len(self.unreachable_hosts) > initial_unreachable_len: + # Recreate client with filtered reachable_hosts, only if hosts were actually pruned + if self.password is None: + self.client = ParallelSSHClient(self.reachable_hosts, user=self.user, pkey=self.pkey, keepalive_seconds=30) + else: + self.client = ParallelSSHClient(self.reachable_hosts, user=self.user, password=self.password, keepalive_seconds=30) + + + def inform_unreachability(self, cmd_output): + """ + Update cmd_output with "Host Unreachable" for all hosts in self.unreachable_hosts. + This ensures that the output dictionary reflects the status of pruned hosts. + """ + for host in self.unreachable_hosts: + cmd_output[host] += "\nABORT: Host Unreachable Error" def exec(self, cmd, timeout=None ): @@ -78,6 +128,10 @@ def exec(self, cmd, timeout=None ): cmd_out_str += exc_str + '\n' cmd_output[item.host] = cmd_out_str + if not self.stop_on_errors: + self.prune_unreachable_hosts(output) + self.inform_unreachability(cmd_output) + return cmd_output @@ -117,6 +171,10 @@ def exec_cmd_list(self, cmd_list, timeout=None ): i=i+1 cmd_output[item.host] = cmd_out_str + if not self.stop_on_errors: + self.prune_unreachable_hosts(output) + self.inform_unreachability(cmd_output) + return cmd_output diff --git a/lib/rocm_plib.py b/lib/rocm_plib.py index a73aee1c..ee57ae7b 100644 --- a/lib/rocm_plib.py +++ b/lib/rocm_plib.py @@ -8,9 +8,9 @@ import re import os import sys -from . import parallel_ssh_lib +import parallel_ssh_lib -from .utils_lib import * +from utils_lib import * def get_rocm_smi_dict( phdl ): diff --git a/lib/unittests/test_html_lib.py b/lib/unittests/test_html_lib.py index e259966a..fffaf6c1 100644 --- a/lib/unittests/test_html_lib.py +++ b/lib/unittests/test_html_lib.py @@ -1,172 +1,172 @@ -import unittest -import tempfile -import os - -# Import the module under test -import lib.html_lib as html_lib - -class TestNormalizeBytes(unittest.TestCase): - - def test_bytes_only(self): - self.assertEqual(html_lib.normalize_bytes(932), "932 B") - - def test_kilobytes_binary(self): - self.assertEqual(html_lib.normalize_bytes(2048), "2 KB") - - def test_kilobytes_decimal(self): - self.assertEqual(html_lib.normalize_bytes(2000, si=True), "2 kB") - - def test_megabytes(self): - self.assertEqual(html_lib.normalize_bytes(5 * 1024 * 1024), "5 MB") - - def test_gigabytes(self): - self.assertEqual(html_lib.normalize_bytes(3 * 1024**3), "3 GB") - - def test_negative_bytes(self): - self.assertEqual(html_lib.normalize_bytes(-1024), "-1 KB") - - def test_precision(self): - self.assertEqual(html_lib.normalize_bytes(1536, precision=1), "1.5 KB") - - def test_type_error(self): - with self.assertRaises(TypeError): - html_lib.normalize_bytes("not a number") - - -class TestBuildHtmlMemUtilizationTable(unittest.TestCase): - - def setUp(self): - self.tmp_file = tempfile.NamedTemporaryFile(delete=False, mode='w+', encoding='utf-8') - self.filename = self.tmp_file.name - - def tearDown(self): - self.tmp_file.close() - os.remove(self.filename) - - def test_single_node_valid_input(self): - use_dict = { - "node1": { - **{f"card{i}": { - "GPU Memory Allocated (VRAM%)": f"{i*10}%", - "GPU Memory Read/Write Activity (%)": f"{i*5}%", - "Memory Activity": f"{i*3}%", - "Avg. Memory Bandwidth": f"{i*2} GB/s" - } for i in range(8)} - } - } - - amd_dict = { - "node1": [ - { - "mem_usage": { - "total_vram": {"value": "16384"}, - "used_vram": {"value": "8192"}, - "free_vram": {"value": "8192"} - } - } for _ in range(8) - ] - } - - html_lib.build_html_mem_utilization_table(self.filename, use_dict, amd_dict) - with open(self.filename, 'r', encoding='utf-8') as f: - content = f.read() - self.assertIn("GPU Memory Utilization", content) - self.assertIn("G0 Tot VRAM MB", content) - self.assertIn("node1", content) - self.assertIn("8192", content) - self.assertIn("10%", content) - - def test_multiple_nodes(self): - use_dict = { - f"node{i}": { - **{f"card{j}": { - "GPU Memory Allocated (VRAM%)": f"{j*10}%", - "GPU Memory Read/Write Activity (%)": f"{j*5}%", - "Memory Activity": f"{j*3}%", - "Avg. Memory Bandwidth": f"{j*2} GB/s" - } for j in range(8)} - } for i in range(2) - } - - amd_dict = { - f"node{i}": [ - { - "mem_usage": { - "total_vram": {"value": "16384"}, - "used_vram": {"value": "8192"}, - "free_vram": {"value": "8192"} - } - } for _ in range(8) - ] for i in range(2) - } - - html_lib.build_html_mem_utilization_table(self.filename, use_dict, amd_dict) - with open(self.filename, 'r', encoding='utf-8') as f: - content = f.read() - self.assertIn("node0", content) - self.assertIn("node1", content) - - def test_rocm7_style_gpu_data(self): - use_dict = { - "node1": { - **{f"card{i}": { - "GPU Memory Allocated (VRAM%)": f"{i*10}%", - "GPU Memory Read/Write Activity (%)": f"{i*5}%", - "Memory Activity": f"{i*3}%", - "Avg. Memory Bandwidth": f"{i*2} GB/s" - } for i in range(8)} - } - } - - amd_dict = { - "node1": { - "gpu_data": [ - { - "mem_usage": { - "total_vram": {"value": "16384"}, - "used_vram": {"value": "8192"}, - "free_vram": {"value": "8192"} - } - } for _ in range(8) - ] - } - } - - html_lib.build_html_mem_utilization_table(self.filename, use_dict, amd_dict) - with open(self.filename, 'r', encoding='utf-8') as f: - content = f.read() - self.assertIn("GPU Memory Utilization", content) - self.assertIn("G0 Tot VRAM MB", content) - self.assertIn("node1", content) - - def test_missing_gpu_key_raises_keyerror(self): - use_dict = { - "node1": { - "card0": { - "GPU Memory Allocated (VRAM%)": "10%", - "GPU Memory Read/Write Activity (%)": "20%", - "Memory Activity": "30%", - "Avg. Memory Bandwidth": "40 GB/s" - } - # Missing card1 to card7 - } - } - - amd_dict = { - "node1": [ - { - "mem_usage": { - "total_vram": {"value": "16384"}, - "used_vram": {"value": "8192"}, - "free_vram": {"value": "8192"} - } - } for _ in range(8) - ] - } - - with self.assertRaises(KeyError): - html_lib.build_html_mem_utilization_table(self.filename, use_dict, amd_dict) - - -if __name__ == '__main__': +import unittest +import tempfile +import os + +# Import the module under test +import lib.html_lib as html_lib + +class TestNormalizeBytes(unittest.TestCase): + + def test_bytes_only(self): + self.assertEqual(html_lib.normalize_bytes(932), "932 B") + + def test_kilobytes_binary(self): + self.assertEqual(html_lib.normalize_bytes(2048), "2 KB") + + def test_kilobytes_decimal(self): + self.assertEqual(html_lib.normalize_bytes(2000, si=True), "2 kB") + + def test_megabytes(self): + self.assertEqual(html_lib.normalize_bytes(5 * 1024 * 1024), "5 MB") + + def test_gigabytes(self): + self.assertEqual(html_lib.normalize_bytes(3 * 1024**3), "3 GB") + + def test_negative_bytes(self): + self.assertEqual(html_lib.normalize_bytes(-1024), "-1 KB") + + def test_precision(self): + self.assertEqual(html_lib.normalize_bytes(1536, precision=1), "1.5 KB") + + def test_type_error(self): + with self.assertRaises(TypeError): + html_lib.normalize_bytes("not a number") + + +class TestBuildHtmlMemUtilizationTable(unittest.TestCase): + + def setUp(self): + self.tmp_file = tempfile.NamedTemporaryFile(delete=False, mode='w+', encoding='utf-8') + self.filename = self.tmp_file.name + + def tearDown(self): + self.tmp_file.close() + os.remove(self.filename) + + def test_single_node_valid_input(self): + use_dict = { + "node1": { + **{f"card{i}": { + "GPU Memory Allocated (VRAM%)": f"{i*10}%", + "GPU Memory Read/Write Activity (%)": f"{i*5}%", + "Memory Activity": f"{i*3}%", + "Avg. Memory Bandwidth": f"{i*2} GB/s" + } for i in range(8)} + } + } + + amd_dict = { + "node1": [ + { + "mem_usage": { + "total_vram": {"value": "16384"}, + "used_vram": {"value": "8192"}, + "free_vram": {"value": "8192"} + } + } for _ in range(8) + ] + } + + html_lib.build_html_mem_utilization_table(self.filename, use_dict, amd_dict) + with open(self.filename, 'r', encoding='utf-8') as f: + content = f.read() + self.assertIn("GPU Memory Utilization", content) + self.assertIn("G0 Tot VRAM MB", content) + self.assertIn("node1", content) + self.assertIn("8192", content) + self.assertIn("10%", content) + + def test_multiple_nodes(self): + use_dict = { + f"node{i}": { + **{f"card{j}": { + "GPU Memory Allocated (VRAM%)": f"{j*10}%", + "GPU Memory Read/Write Activity (%)": f"{j*5}%", + "Memory Activity": f"{j*3}%", + "Avg. Memory Bandwidth": f"{j*2} GB/s" + } for j in range(8)} + } for i in range(2) + } + + amd_dict = { + f"node{i}": [ + { + "mem_usage": { + "total_vram": {"value": "16384"}, + "used_vram": {"value": "8192"}, + "free_vram": {"value": "8192"} + } + } for _ in range(8) + ] for i in range(2) + } + + html_lib.build_html_mem_utilization_table(self.filename, use_dict, amd_dict) + with open(self.filename, 'r', encoding='utf-8') as f: + content = f.read() + self.assertIn("node0", content) + self.assertIn("node1", content) + + def test_rocm7_style_gpu_data(self): + use_dict = { + "node1": { + **{f"card{i}": { + "GPU Memory Allocated (VRAM%)": f"{i*10}%", + "GPU Memory Read/Write Activity (%)": f"{i*5}%", + "Memory Activity": f"{i*3}%", + "Avg. Memory Bandwidth": f"{i*2} GB/s" + } for i in range(8)} + } + } + + amd_dict = { + "node1": { + "gpu_data": [ + { + "mem_usage": { + "total_vram": {"value": "16384"}, + "used_vram": {"value": "8192"}, + "free_vram": {"value": "8192"} + } + } for _ in range(8) + ] + } + } + + html_lib.build_html_mem_utilization_table(self.filename, use_dict, amd_dict) + with open(self.filename, 'r', encoding='utf-8') as f: + content = f.read() + self.assertIn("GPU Memory Utilization", content) + self.assertIn("G0 Tot VRAM MB", content) + self.assertIn("node1", content) + + def test_missing_gpu_key_raises_keyerror(self): + use_dict = { + "node1": { + "card0": { + "GPU Memory Allocated (VRAM%)": "10%", + "GPU Memory Read/Write Activity (%)": "20%", + "Memory Activity": "30%", + "Avg. Memory Bandwidth": "40 GB/s" + } + # Missing card1 to card7 + } + } + + amd_dict = { + "node1": [ + { + "mem_usage": { + "total_vram": {"value": "16384"}, + "used_vram": {"value": "8192"}, + "free_vram": {"value": "8192"} + } + } for _ in range(8) + ] + } + + with self.assertRaises(KeyError): + html_lib.build_html_mem_utilization_table(self.filename, use_dict, amd_dict) + + +if __name__ == '__main__': unittest.main() \ No newline at end of file diff --git a/lib/unittests/test_parallel_ssh_lib.py b/lib/unittests/test_parallel_ssh_lib.py index fb094c5f..42140cc2 100644 --- a/lib/unittests/test_parallel_ssh_lib.py +++ b/lib/unittests/test_parallel_ssh_lib.py @@ -8,6 +8,7 @@ class TestPsshExec(unittest.TestCase): def setUp(self, mock_pssh_client): self.mock_client = MagicMock() mock_pssh_client.return_value = self.mock_client + self.mock_pssh_client = mock_pssh_client self.host_list = ['host1', 'host2'] self.pssh = Pssh('log', self.host_list, user='user', password='pass') @@ -49,7 +50,8 @@ def test_exec_with_exception_stop_on_errors_true(self): # Since exception was raised, result was not returned self.assertNotIn('result', locals()) - def test_exec_with_exception_stop_on_errors_false(self): + @patch.object(Pssh, 'check_connectivity') + def test_exec_with_exception_stop_on_errors_false(self, mock_check_connectivity): # Test Case 2.2: Execute command with timeout and stop_on_errors=False # Exception should not be raised instead populated in output for failed hosts, success for others self.pssh.stop_on_errors = False @@ -67,6 +69,7 @@ def test_exec_with_exception_stop_on_errors_false(self): mock_output2.exception = Timeout('Command timed out') self.mock_client.run_command.return_value = [mock_output1, mock_output2] + mock_check_connectivity.return_value = True # No pruning result = self.pssh.exec('echo hello', timeout=10) @@ -76,6 +79,187 @@ def test_exec_with_exception_stop_on_errors_false(self): self.assertIn('success output', result['host1']) self.assertIn('Command timed out', result['host2']) + @patch('lib.parallel_ssh_lib.ParallelSSHClient') + @patch.object(Pssh, 'check_connectivity') + def test_exec_with_pruning_unreachable_host(self, mock_check_connectivity, mock_pssh_client): + # Test: With stop_on_errors=False, timeout on host2, and check_connectivity fails for host2, prune it + self.mock_client = MagicMock() + mock_pssh_client.return_value = self.mock_client + self.host_list = ['host1', 'host2'] + self.pssh = Pssh('log', self.host_list, user='user', password='pass') + self.pssh.stop_on_errors = False + self.pssh.check_connectivity = mock_check_connectivity + from pssh.exceptions import Timeout + mock_output1 = MagicMock() + mock_output1.host = 'host1' + mock_output1.stdout = ['success output'] + mock_output1.stderr = [] + mock_output1.exception = None + + mock_output2 = MagicMock() + mock_output2.host = 'host2' + mock_output2.stdout = [] + mock_output2.stderr = [] + mock_output2.exception = Timeout('Command timed out') + + self.mock_client.run_command.return_value = [mock_output1, mock_output2] + mock_check_connectivity.return_value = False # Simulate unreachable + + result = self.pssh.exec('echo hello', timeout=10) + + self.assertEqual(self.pssh.reachable_hosts, ['host1']) + self.assertEqual(self.pssh.unreachable_hosts, ['host2']) + self.assertIn('host1', result) + self.assertIn('host2', result) + self.assertIn('success output', result['host1']) + self.assertEqual(result['host2'], 'Command timed out\n\nABORT: Host Unreachable Error') + # Client should be recreated once (init + prune) + self.assertEqual(mock_pssh_client.call_count, 2) + + @patch('lib.parallel_ssh_lib.ParallelSSHClient') + @patch.object(Pssh, 'check_connectivity') + def test_exec_no_pruning_when_reachable(self, mock_check_connectivity, mock_pssh_client): + # Test: With stop_on_errors=False, timeout on host2, but check_connectivity succeeds, no pruning + self.mock_client = MagicMock() + mock_pssh_client.return_value = self.mock_client + self.host_list = ['host1', 'host2'] + self.pssh = Pssh('log', self.host_list, user='user', password='pass') + self.pssh.stop_on_errors = False + self.pssh.check_connectivity = mock_check_connectivity + from pssh.exceptions import Timeout + mock_output1 = MagicMock() + mock_output1.host = 'host1' + mock_output1.stdout = ['success output'] + mock_output1.stderr = [] + mock_output1.exception = None + + mock_output2 = MagicMock() + mock_output2.host = 'host2' + mock_output2.stdout = [] + mock_output2.stderr = [] + mock_output2.exception = Timeout('Command timed out') + + self.mock_client.run_command.return_value = [mock_output1, mock_output2] + mock_check_connectivity.return_value = True # Always reachable + + result = self.pssh.exec('echo hello', timeout=10) + + self.assertEqual(self.pssh.reachable_hosts, ['host1', 'host2']) # No change + self.assertEqual(self.pssh.unreachable_hosts, []) + self.assertIn('host1', result) + self.assertIn('host2', result) + self.assertIn('success output', result['host1']) + self.assertIn('Command timed out', result['host2']) # Original exception + # Client not recreated + self.assertEqual(mock_pssh_client.call_count, 1) + + @patch('lib.parallel_ssh_lib.ParallelSSHClient') + @patch.object(Pssh, 'check_connectivity') + def test_exec_pruning_with_multiple_unreachable_hosts(self, mock_check_connectivity, mock_pssh_client): + # Test: With stop_on_errors=False, multiple hosts (host2, host3) timeout and are unreachable, prune all + self.mock_client = MagicMock() + mock_pssh_client.return_value = self.mock_client + self.host_list = ['host1', 'host2', 'host3'] + self.pssh = Pssh('log', self.host_list, user='user', password='pass') + self.pssh.stop_on_errors = False + self.pssh.check_connectivity = mock_check_connectivity + from pssh.exceptions import Timeout + mock_output1 = MagicMock() + mock_output1.host = 'host1' + mock_output1.stdout = ['success output'] + mock_output1.stderr = [] + mock_output1.exception = None + + mock_output2 = MagicMock() + mock_output2.host = 'host2' + mock_output2.stdout = [] + mock_output2.stderr = [] + mock_output2.exception = Timeout('Command timed out') + + mock_output3 = MagicMock() + mock_output3.host = 'host3' + mock_output3.stdout = [] + mock_output3.stderr = [] + mock_output3.exception = Timeout('Command timed out') + + self.mock_client.run_command.return_value = [mock_output1, mock_output2, mock_output3] + mock_check_connectivity.return_value = False # Simulate all unreachable + + result = self.pssh.exec('echo hello', timeout=10) + + self.assertEqual(self.pssh.reachable_hosts, ['host1']) + self.assertEqual(sorted(self.pssh.unreachable_hosts), ['host2', 'host3']) + self.assertIn('host1', result) + self.assertIn('host2', result) + self.assertIn('host3', result) + self.assertIn('success output', result['host1']) + self.assertEqual(result['host2'], 'Command timed out\n\nABORT: Host Unreachable Error') + self.assertEqual(result['host3'], 'Command timed out\n\nABORT: Host Unreachable Error') + # Client should be recreated once (init + prune) + self.assertEqual(mock_pssh_client.call_count, 2) + + @patch('lib.parallel_ssh_lib.ParallelSSHClient') + def test_exec_no_pruning_on_non_timeout_exception(self, mock_pssh_client): + # Test: With stop_on_errors=False, non-Timeout exception on host2, no pruning occurs + # Non-Timeout exceptions are treated as transient failures (e.g., authentication or network glitches), + # not indicative of permanent unreachability, so we don't attempt connectivity checks or pruning + self.mock_client = MagicMock() + mock_pssh_client.return_value = self.mock_client + self.host_list = ['host1', 'host2'] + self.pssh = Pssh('log', self.host_list, user='user', password='pass') + self.pssh.stop_on_errors = False + from pssh.exceptions import ConnectionError + mock_output1 = MagicMock() + mock_output1.host = 'host1' + mock_output1.stdout = ['success output'] + mock_output1.stderr = [] + mock_output1.exception = None + + mock_output2 = MagicMock() + mock_output2.host = 'host2' + mock_output2.stdout = [] + mock_output2.stderr = [] + mock_output2.exception = ConnectionError('Connection failed') + + self.mock_client.run_command.return_value = [mock_output1, mock_output2] + + result = self.pssh.exec('echo hello', timeout=10) + + self.assertEqual(self.pssh.reachable_hosts, ['host1', 'host2']) # No pruning + self.assertEqual(self.pssh.unreachable_hosts, []) + self.assertIn('host1', result) + self.assertIn('host2', result) + self.assertIn('success output', result['host1']) + self.assertIn('Connection failed', result['host2']) # Original exception + # Client not recreated + self.assertEqual(mock_pssh_client.call_count, 1) + + @patch.object(Pssh, 'prune_unreachable_hosts') + @patch.object(Pssh, 'inform_unreachability') + def test_exec_no_pruning_when_stop_on_errors_true(self, mock_inform, mock_prune): + # Test: With stop_on_errors=True, no pruning even with timeout + # Since stop_on_errors=True, run_command raises immediately, so prune_unreachable_hosts and inform_unreachability are not invoked + from pssh.exceptions import Timeout + self.mock_client.run_command.side_effect = Timeout('Command timed out') + + with self.assertRaises(Timeout): + self.pssh.exec('echo hello', timeout=10) + + # Assert that pruning methods were not called + mock_prune.assert_not_called() + mock_inform.assert_not_called() + + +class TestPsshExecCmdList(unittest.TestCase): + + @patch('lib.parallel_ssh_lib.ParallelSSHClient') + def setUp(self, mock_pssh_client): + self.mock_client = MagicMock() + mock_pssh_client.return_value = self.mock_client + self.mock_pssh_client = mock_pssh_client + self.host_list = ['host1', 'host2'] + self.pssh = Pssh('log', self.host_list, user='user', password='pass') + def test_exec_cmd_list_successful(self): # Test: Execute different commands on different hosts successfully cmd_list = ['echo host1', 'echo host2'] @@ -101,7 +285,8 @@ def test_exec_cmd_list_successful(self): self.assertIn('host1', result['host1']) self.assertIn('host2', result['host2']) - def test_exec_cmd_list_with_exception_stop_on_errors_false(self): + @patch.object(Pssh, 'check_connectivity') + def test_exec_cmd_list_with_exception_stop_on_errors_false(self, mock_check_connectivity): # Test: Handle exceptions with stop_on_errors=False for exec_cmd_list # Exception should not be raised instead populated in output for failed hosts, success for others self.pssh.stop_on_errors = False @@ -120,6 +305,7 @@ def test_exec_cmd_list_with_exception_stop_on_errors_false(self): mock_output2.exception = Timeout('Command timed out') self.mock_client.run_command.return_value = [mock_output1, mock_output2] + mock_check_connectivity.return_value = True # Simulate reachable, no pruning result = self.pssh.exec_cmd_list(cmd_list, timeout=10) @@ -142,5 +328,178 @@ def test_exec_cmd_list_with_exception_stop_on_errors_true(self): self.assertIn('Command timed out', str(cm.exception)) self.assertNotIn('result', locals()) + @patch('lib.parallel_ssh_lib.ParallelSSHClient') + @patch.object(Pssh, 'check_connectivity') + def test_exec_cmd_list_no_pruning_when_reachable(self, mock_check_connectivity, mock_pssh_client): + # Test: exec_cmd_list with stop_on_errors=False, timeout on host2, but check_connectivity succeeds, no pruning + self.mock_client = MagicMock() + mock_pssh_client.return_value = self.mock_client + self.host_list = ['host1', 'host2'] + self.pssh = Pssh('log', self.host_list, user='user', password='pass') + self.pssh.stop_on_errors = False + self.pssh.check_connectivity = mock_check_connectivity + cmd_list = ['echo success', 'echo fail'] + from pssh.exceptions import Timeout + mock_output1 = MagicMock() + mock_output1.host = 'host1' + mock_output1.stdout = ['success'] + mock_output1.stderr = [] + mock_output1.exception = None + + mock_output2 = MagicMock() + mock_output2.host = 'host2' + mock_output2.stdout = [] + mock_output2.stderr = [] + mock_output2.exception = Timeout('Command timed out') + + self.mock_client.run_command.return_value = [mock_output1, mock_output2] + mock_check_connectivity.return_value = True # Always reachable + + result = self.pssh.exec_cmd_list(cmd_list, timeout=10) + + self.assertEqual(self.pssh.reachable_hosts, ['host1', 'host2']) # No change + self.assertEqual(self.pssh.unreachable_hosts, []) + self.assertIn('host1', result) + self.assertIn('host2', result) + self.assertIn('success', result['host1']) + self.assertIn('Command timed out', result['host2']) # Original exception + # Client not recreated + self.assertEqual(mock_pssh_client.call_count, 1) + + @patch('lib.parallel_ssh_lib.ParallelSSHClient') + @patch.object(Pssh, 'check_connectivity') + def test_exec_cmd_list_with_pruning(self, mock_check_connectivity, mock_pssh_client): + # Test: exec_cmd_list with pruning + self.mock_client = MagicMock() + mock_pssh_client.return_value = self.mock_client + self.host_list = ['host1', 'host2'] + self.pssh = Pssh('log', self.host_list, user='user', password='pass') + self.pssh.stop_on_errors = False + self.pssh.check_connectivity = mock_check_connectivity + cmd_list = ['echo success', 'echo fail'] + from pssh.exceptions import Timeout + mock_output1 = MagicMock() + mock_output1.host = 'host1' + mock_output1.stdout = ['success'] + mock_output1.stderr = [] + mock_output1.exception = None + + mock_output2 = MagicMock() + mock_output2.host = 'host2' + mock_output2.stdout = [] + mock_output2.stderr = [] + mock_output2.exception = Timeout('Command timed out') + + self.mock_client.run_command.return_value = [mock_output1, mock_output2] + mock_check_connectivity.return_value = False + + result = self.pssh.exec_cmd_list(cmd_list, timeout=10) + + self.assertEqual(self.pssh.reachable_hosts, ['host1']) + self.assertEqual(self.pssh.unreachable_hosts, ['host2']) + self.assertIn('host1', result) + self.assertIn('host2', result) + self.assertIn('success', result['host1']) + self.assertEqual(result['host2'], 'Command timed out\n\nABORT: Host Unreachable Error') + self.assertEqual(mock_pssh_client.call_count, 2) + + @patch('lib.parallel_ssh_lib.ParallelSSHClient') + @patch.object(Pssh, 'check_connectivity') + def test_exec_cmd_list_pruning_with_multiple_unreachable_hosts(self, mock_check_connectivity, mock_pssh_client): + # Test: exec_cmd_list with pruning for multiple unreachable hosts + self.mock_client = MagicMock() + mock_pssh_client.return_value = self.mock_client + self.host_list = ['host1', 'host2', 'host3'] + self.pssh = Pssh('log', self.host_list, user='user', password='pass') + self.pssh.stop_on_errors = False + self.pssh.check_connectivity = mock_check_connectivity + cmd_list = ['echo success', 'echo fail1', 'echo fail2'] + from pssh.exceptions import Timeout + mock_output1 = MagicMock() + mock_output1.host = 'host1' + mock_output1.stdout = ['success'] + mock_output1.stderr = [] + mock_output1.exception = None + + mock_output2 = MagicMock() + mock_output2.host = 'host2' + mock_output2.stdout = [] + mock_output2.stderr = [] + mock_output2.exception = Timeout('Command timed out') + + mock_output3 = MagicMock() + mock_output3.host = 'host3' + mock_output3.stdout = [] + mock_output3.stderr = [] + mock_output3.exception = Timeout('Command timed out') + + self.mock_client.run_command.return_value = [mock_output1, mock_output2, mock_output3] + mock_check_connectivity.return_value = False + + result = self.pssh.exec_cmd_list(cmd_list, timeout=10) + + self.assertEqual(self.pssh.reachable_hosts, ['host1']) + self.assertEqual(sorted(self.pssh.unreachable_hosts), ['host2', 'host3']) + self.assertIn('host1', result) + self.assertIn('host2', result) + self.assertIn('host3', result) + self.assertIn('success', result['host1']) + self.assertEqual(result['host2'], 'Command timed out\n\nABORT: Host Unreachable Error') + self.assertEqual(result['host3'], 'Command timed out\n\nABORT: Host Unreachable Error') + self.assertEqual(mock_pssh_client.call_count, 2) + + @patch('lib.parallel_ssh_lib.ParallelSSHClient') + def test_exec_cmd_list_no_pruning_on_non_timeout_exception(self, mock_pssh_client): + # Test: exec_cmd_list with non-Timeout exception, no pruning occurs + # Non-Timeout exceptions are treated as transient failures, so no connectivity checks or pruning + self.mock_client = MagicMock() + mock_pssh_client.return_value = self.mock_client + self.host_list = ['host1', 'host2'] + self.pssh = Pssh('log', self.host_list, user='user', password='pass') + self.pssh.stop_on_errors = False + cmd_list = ['echo success', 'echo fail'] + from pssh.exceptions import ConnectionError + mock_output1 = MagicMock() + mock_output1.host = 'host1' + mock_output1.stdout = ['success'] + mock_output1.stderr = [] + mock_output1.exception = None + + mock_output2 = MagicMock() + mock_output2.host = 'host2' + mock_output2.stdout = [] + mock_output2.stderr = [] + mock_output2.exception = ConnectionError('Connection failed') + + self.mock_client.run_command.return_value = [mock_output1, mock_output2] + + result = self.pssh.exec_cmd_list(cmd_list, timeout=10) + + self.assertEqual(self.pssh.reachable_hosts, ['host1', 'host2']) # No pruning + self.assertEqual(self.pssh.unreachable_hosts, []) + self.assertIn('host1', result) + self.assertIn('host2', result) + self.assertIn('success', result['host1']) + self.assertIn('Connection failed', result['host2']) # Original exception + # Client not recreated + self.assertEqual(mock_pssh_client.call_count, 1) + + @patch.object(Pssh, 'prune_unreachable_hosts') + @patch.object(Pssh, 'inform_unreachability') + def test_exec_cmd_list_no_pruning_when_stop_on_errors_true(self, mock_inform, mock_prune): + # Test: exec_cmd_list with stop_on_errors=True, no pruning even with timeout + # Since stop_on_errors=True, run_command raises immediately, so prune_unreachable_hosts and inform_unreachability are not invoked + cmd_list = ['echo test'] + from pssh.exceptions import Timeout + self.mock_client.run_command.side_effect = Timeout('Command timed out') + + with self.assertRaises(Timeout): + self.pssh.exec_cmd_list(cmd_list, timeout=5) + + # Assert that pruning methods were not called + mock_prune.assert_not_called() + mock_inform.assert_not_called() + + if __name__ == '__main__': - unittest.main() \ No newline at end of file + unittest.main() diff --git a/lib/unittests/test_verify_lib.py b/lib/unittests/test_verify_lib.py index c7d1722f..17f7ea87 100644 --- a/lib/unittests/test_verify_lib.py +++ b/lib/unittests/test_verify_lib.py @@ -1,91 +1,91 @@ -import unittest -from unittest.mock import MagicMock, patch - - -import lib.verify_lib as verify_lib - -class TestVerifyGpuPcieBusWidth(unittest.TestCase): - - @patch('lib.verify_lib.get_gpu_pcie_bus_dict') - @patch('lib.verify_lib.fail_test') - def test_valid_bus_width(self, mock_fail_test, mock_get_bus_dict): - mock_get_bus_dict.return_value = { - 'node1': { - 'card0': {'PCI Bus': '0000:01:00.0'}, - 'card1': {'PCI Bus': '0000:02:00.0'} - }, - 'node2': { - 'card0': {'PCI Bus': '0000:03:00.0'}, - 'card1': {'PCI Bus': '0000:04:00.0'} - } - } - - phdl = MagicMock() - phdl.exec_cmd_list.return_value = { - 'node1': 'LnkSta: Speed 32GT/s, Width x16', - 'node2': 'LnkSta: Speed 32GT/s, Width x16' - } - - result = verify_lib.verify_gpu_pcie_bus_width(phdl, expected_cards=2) - self.assertEqual(result, {'node1': [], 'node2': []}) - mock_fail_test.assert_not_called() - - @patch('lib.verify_lib.get_gpu_pcie_bus_dict') - @patch('lib.verify_lib.fail_test') - def test_invalid_bus_speed(self, mock_fail_test, mock_get_bus_dict): - mock_get_bus_dict.return_value = { - 'node1': { - 'card0': {'PCI Bus': '0000:01:00.0'} - } - } - - phdl = MagicMock() - phdl.exec_cmd_list.return_value = { - 'node1': 'LnkSta: Speed 16GT/s, Width x16' - } - - verify_lib.verify_gpu_pcie_bus_width(phdl, expected_cards=1) - mock_fail_test.assert_called() - - -class TestVerifyGpuPcieErrors(unittest.TestCase): - - @patch('lib.verify_lib.get_gpu_metrics_dict') - @patch('lib.verify_lib.fail_test') - def test_valid_error_metrics(self, mock_fail_test, mock_get_metrics): - mock_get_metrics.return_value = { - 'node1': { - 'card0': { - 'pcie_l0_to_recov_count_acc (Count)': '10', - 'pcie_nak_sent_count_acc (Count)': '20', - 'pcie_nak_rcvd_count_acc (Count)': '30' - } - } - } - - phdl = MagicMock() - result = verify_lib.verify_gpu_pcie_errors(phdl) - self.assertEqual(result, {'node1': []}) - mock_fail_test.assert_not_called() - - @patch('lib.verify_lib.get_gpu_metrics_dict') - @patch('lib.verify_lib.fail_test') - def test_threshold_exceeded(self, mock_fail_test, mock_get_metrics): - mock_get_metrics.return_value = { - 'node1': { - 'card0': { - 'pcie_l0_to_recov_count_acc (Count)': '101', - 'pcie_nak_sent_count_acc (Count)': '150', - 'pcie_nak_rcvd_count_acc (Count)': '200' - } - } - } - - phdl = MagicMock() - result = verify_lib.verify_gpu_pcie_errors(phdl) - self.assertEqual(len(result['node1']), 3) - mock_fail_test.assert_called() - - -if __name__ == '__main__': +import unittest +from unittest.mock import MagicMock, patch + + +import lib.verify_lib as verify_lib + +class TestVerifyGpuPcieBusWidth(unittest.TestCase): + + @patch('lib.verify_lib.get_gpu_pcie_bus_dict') + @patch('lib.verify_lib.fail_test') + def test_valid_bus_width(self, mock_fail_test, mock_get_bus_dict): + mock_get_bus_dict.return_value = { + 'node1': { + 'card0': {'PCI Bus': '0000:01:00.0'}, + 'card1': {'PCI Bus': '0000:02:00.0'} + }, + 'node2': { + 'card0': {'PCI Bus': '0000:03:00.0'}, + 'card1': {'PCI Bus': '0000:04:00.0'} + } + } + + phdl = MagicMock() + phdl.exec_cmd_list.return_value = { + 'node1': 'LnkSta: Speed 32GT/s, Width x16', + 'node2': 'LnkSta: Speed 32GT/s, Width x16' + } + + result = verify_lib.verify_gpu_pcie_bus_width(phdl, expected_cards=2) + self.assertEqual(result, {'node1': [], 'node2': []}) + mock_fail_test.assert_not_called() + + @patch('lib.verify_lib.get_gpu_pcie_bus_dict') + @patch('lib.verify_lib.fail_test') + def test_invalid_bus_speed(self, mock_fail_test, mock_get_bus_dict): + mock_get_bus_dict.return_value = { + 'node1': { + 'card0': {'PCI Bus': '0000:01:00.0'} + } + } + + phdl = MagicMock() + phdl.exec_cmd_list.return_value = { + 'node1': 'LnkSta: Speed 16GT/s, Width x16' + } + + verify_lib.verify_gpu_pcie_bus_width(phdl, expected_cards=1) + mock_fail_test.assert_called() + + +class TestVerifyGpuPcieErrors(unittest.TestCase): + + @patch('lib.verify_lib.get_gpu_metrics_dict') + @patch('lib.verify_lib.fail_test') + def test_valid_error_metrics(self, mock_fail_test, mock_get_metrics): + mock_get_metrics.return_value = { + 'node1': { + 'card0': { + 'pcie_l0_to_recov_count_acc (Count)': '10', + 'pcie_nak_sent_count_acc (Count)': '20', + 'pcie_nak_rcvd_count_acc (Count)': '30' + } + } + } + + phdl = MagicMock() + result = verify_lib.verify_gpu_pcie_errors(phdl) + self.assertEqual(result, {'node1': []}) + mock_fail_test.assert_not_called() + + @patch('lib.verify_lib.get_gpu_metrics_dict') + @patch('lib.verify_lib.fail_test') + def test_threshold_exceeded(self, mock_fail_test, mock_get_metrics): + mock_get_metrics.return_value = { + 'node1': { + 'card0': { + 'pcie_l0_to_recov_count_acc (Count)': '101', + 'pcie_nak_sent_count_acc (Count)': '150', + 'pcie_nak_rcvd_count_acc (Count)': '200' + } + } + } + + phdl = MagicMock() + result = verify_lib.verify_gpu_pcie_errors(phdl) + self.assertEqual(len(result['node1']), 3) + mock_fail_test.assert_called() + + +if __name__ == '__main__': unittest.main() \ No newline at end of file diff --git a/lib/utils_lib.py b/lib/utils_lib.py index 23951dc1..bc8f0387 100644 --- a/lib/utils_lib.py +++ b/lib/utils_lib.py @@ -11,7 +11,7 @@ import json import pytest -from . import globals +import globals log = globals.log diff --git a/lib/verify_lib.py b/lib/verify_lib.py index ccaa806b..86eb4b74 100644 --- a/lib/verify_lib.py +++ b/lib/verify_lib.py @@ -9,9 +9,9 @@ import re import sys -from .utils_lib import * -from .rocm_plib import * -from . import linux_utils +from utils_lib import * +from rocm_plib import * +import linux_utils err_patterns_dict = { diff --git a/run_all_unittests.py b/run_all_unittests.py index 77519ffb..7964218a 100644 --- a/run_all_unittests.py +++ b/run_all_unittests.py @@ -1,12 +1,23 @@ -# run_all_unittests.py -import unittest - -loader = unittest.TestLoader() -suite = unittest.TestSuite() - -# Add all unit test directories -for test_dir in ['lib/unittests']: - suite.addTests(loader.discover(start_dir=test_dir)) - -runner = unittest.TextTestRunner(verbosity=2) -runner.run(suite) \ No newline at end of file +# run_all_unittests.py +import sys +import os + +# Add lib directory to sys.path for absolute imports +lib_path = os.path.join(os.path.dirname(__file__), 'lib') +sys.path.insert(0, lib_path) + +import unittest + +def main(): + loader = unittest.TestLoader() + suite = unittest.TestSuite() + + # Add all unit test directories + for test_dir in ['lib/unittests']: + suite.addTests(loader.discover(start_dir=test_dir)) + + runner = unittest.TextTestRunner(verbosity=2) + runner.run(suite) + +if __name__ == '__main__': + main() \ No newline at end of file