diff --git a/Makefile b/Makefile
new file mode 100644
index 00000000..f89e2a7e
--- /dev/null
+++ b/Makefile
@@ -0,0 +1,50 @@
+# Developer convenience targets: build a throwaway virtualenv, install the
+# requirements into it and run the unit tests from that environment.
+VENV_DIR = test_venv
+PYTHON := $(shell command -v python3 || command -v python)
+PIP = $(VENV_DIR)/bin/pip
+
+# Every target below is a command, not a file, so all of them are phony.
+.PHONY: help venv install test clean all clean_venv clean_pycache
+all: venv install test
+
+help:
+	@echo "Available targets:"
+	@echo "  venv          - Create virtual environment"
+	@echo "  install       - Install requirements"
+	@echo "  test          - Run the unit tests (run_all_unittests.py)"
+	@echo "  all           - Run venv, install, and test"
+	@echo "  clean         - Remove virtual environment and Python cache files"
+	@echo "  clean_venv    - Remove virtual environment only"
+	@echo "  clean_pycache - Remove Python cache files only"
+
+# Always starts from scratch: clean_venv deletes any previous environment first.
+venv: clean_venv
+	@echo "Creating virtual environment..."
+	$(PYTHON) -m venv $(VENV_DIR)
+
+install: venv
+	@echo "Installing requirements..."
+	$(PIP) install -r requirements.txt
+
+test: install
+	@echo "Unit Testing cvs..."
+	$(VENV_DIR)/bin/python run_all_unittests.py
+
+# Refuses to delete the venv that the calling shell currently has activated.
+clean_venv:
+	@echo "Removing virtual environment..."
+	@if [ -n "$$VIRTUAL_ENV" ] && [ "$$VIRTUAL_ENV" = "$$(pwd)/$(VENV_DIR)" ]; then \
+		echo "ERROR: You are currently in the venv. Please run 'deactivate' first."; \
+		exit 1; \
+	fi
+	rm -rf $(VENV_DIR)
+
+# Best effort: find's error output is discarded and a failure never fails the target.
+clean_pycache:
+	@echo "Removing Python cache files..."
+	find . -type d -name "__pycache__" -exec rm -rf {} + 2>/dev/null || true
+	find . -name "*.pyc" -delete 2>/dev/null || true
+	find . -name "*.pyo" -delete 2>/dev/null || true
+
+clean: clean_venv clean_pycache
diff --git a/UNIT_TESTING_GUIDE.md b/UNIT_TESTING_GUIDE.md
new file mode 100644
index 00000000..511ca1b5
--- /dev/null
+++ b/UNIT_TESTING_GUIDE.md
@@ -0,0 +1,92 @@
+
+# Unit Test Organization
+
+This project separates **unit tests** and **cluster validation tests** to maintain clarity and modularity.
+
+---
+
+## 📁 Directory Structure
+
+```
+cvs/
+├── lib/
+│   ├── docker_lib.py
+│   ├── utils_lib.py
+│   ├── verify_lib.py
+│   └── unittests/           # Unit tests for lib modules
+│       ├── __init__.py
+│       ├── test_verify_lib.py
+│       └── test_html_lib.py
+├── tests/                   # cluster validation tests (Pytest)
+│   ├── health
+│   ├── ibperf
+│   └── rccl
+├── conftest.py
+├── pytest.ini
+└── README.md
+```
+
+---
+
+## ✅ Unit Tests (using `unittest`)
+
+- Use Python's built-in `unittest` framework
+- Each test file should be named `test_*.py`
+- Each test class should inherit from `unittest.TestCase`
+
+### 🔧 How to Run Unit Tests
+
+From the project root (`cvs/`):
+
+```bash
+python -m unittest discover -s lib/unittests
+```
+
+This will discover and run all unit tests under `lib/unittests/`.
+
+---
+
+## 🛠️ Run All Unit Tests from Multiple Folders
+
+If you have multiple unit test folders (e.g., `lib/unittests`, `utils/unittests`), run them together with the `run_all_unittests.py` script in the project root.
+
+Run it with:
+
+```bash
+python run_all_unittests.py
+test_missing_gpu_key_raises_keyerror (test_html_lib.TestBuildHtmlMemUtilizationTable.test_missing_gpu_key_raises_keyerror) ... Build HTML mem utilization table
+ok
+test_multiple_nodes (test_html_lib.TestBuildHtmlMemUtilizationTable.test_multiple_nodes) ... Build HTML mem utilization table
+ok
+test_rocm7_style_gpu_data (test_html_lib.TestBuildHtmlMemUtilizationTable.test_rocm7_style_gpu_data) ... Build HTML mem utilization table
+ok
+test_single_node_valid_input (test_html_lib.TestBuildHtmlMemUtilizationTable.test_single_node_valid_input) ... Build HTML mem utilization table
+ok
+test_bytes_only (test_html_lib.TestNormalizeBytes.test_bytes_only) ... ok
+test_gigabytes (test_html_lib.TestNormalizeBytes.test_gigabytes) ... ok
+test_kilobytes_binary (test_html_lib.TestNormalizeBytes.test_kilobytes_binary) ... 
ok
+test_kilobytes_decimal (test_html_lib.TestNormalizeBytes.test_kilobytes_decimal) ... ok
+test_megabytes (test_html_lib.TestNormalizeBytes.test_megabytes) ... ok
+test_negative_bytes (test_html_lib.TestNormalizeBytes.test_negative_bytes) ... ok
+test_precision (test_html_lib.TestNormalizeBytes.test_precision) ... ok
+test_type_error (test_html_lib.TestNormalizeBytes.test_type_error) ... ok
+test_invalid_bus_speed (test_verify_lib.TestVerifyGpuPcieBusWidth.test_invalid_bus_speed) ... ok
+test_valid_bus_width (test_verify_lib.TestVerifyGpuPcieBusWidth.test_valid_bus_width) ... ok
+test_threshold_exceeded (test_verify_lib.TestVerifyGpuPcieErrors.test_threshold_exceeded) ... ok
+test_valid_error_metrics (test_verify_lib.TestVerifyGpuPcieErrors.test_valid_error_metrics) ... ok
+
+----------------------------------------------------------------------
+Ran 16 tests in 0.026s
+
+OK
+```
+
+---
+
+## 🧪 Tips for Organizing Tests
+
+- Keep unit tests close to the code they test (e.g., `lib/unittests/`)
+- Use `__init__.py` in all test directories to make them importable
+- Use `unittest.mock` to isolate unit tests from external dependencies
+
+---
diff --git a/lib/__init__.py b/lib/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/lib/parallel_ssh_lib.py b/lib/parallel_ssh_lib.py
index 8b031f1c..5228ca2e 100755
--- a/lib/parallel_ssh_lib.py
+++ b/lib/parallel_ssh_lib.py
@@ -7,12 +7,14 @@ from __future__ import print_function
 
 
 from pssh.clients import ParallelSSHClient
+from pssh.exceptions import Timeout
 
 import sys
 import os
 import re
 import ast
 import json
+import time
 
 # Following used only for scp of file
 import paramiko
@@ -28,22 +30,71 @@ class Pssh():
     mandatory args = user, password (or) 'private_key': load_private_key('my_key.pem')
     """
 
-    def __init__(self, log, host_list, user=None, password=None, pkey='id_rsa', host_key_check=False ):
+    def __init__(self, log, host_list, user=None, password=None, pkey='id_rsa', host_key_check=False, stop_on_errors=True ):
 
-        self.log = log
-        self.host_list = host_list
-        self.user = user
-        self.pkey = pkey
-        self.password = password
+        self.log = log
+        self.host_list = host_list
+        self.reachable_hosts = list(host_list)  # private copy: pruning must never mutate the caller's list
+        self.user = user
+        self.pkey = pkey
+        self.password = password
         self.host_key_check = host_key_check
+        self.stop_on_errors = stop_on_errors
+        self.unreachable_hosts = []
 
         if self.password is None:
-            print(self.host_list)
+            print(self.reachable_hosts)
             print(self.user)
             print(self.pkey)
-            self.client = ParallelSSHClient( self.host_list, user=self.user, pkey=self.pkey, keepalive_seconds=30 )
+            self.client = ParallelSSHClient( self.reachable_hosts, user=self.user, pkey=self.pkey, keepalive_seconds=30 )
         else:
-            self.client = ParallelSSHClient( self.host_list, user=self.user, password=self.password, keepalive_seconds=30 )
+            self.client = ParallelSSHClient( self.reachable_hosts, user=self.user, password=self.password, keepalive_seconds=30 )
+
+
+    def check_connectivity(self, host):
+        """
+        Return True if 'echo 1' can be launched on host through a fresh single-host client, False on any exception.
+        """
+        try:
+            temp_client = ParallelSSHClient([host], user=self.user, pkey=self.pkey if self.password is None else None, password=self.password, keepalive_seconds=30)
+            temp_client.run_command('echo 1', stop_on_errors=True, read_timeout=5)
+            return True
+        except Exception:
+            return False
+
+    def prune_unreachable_hosts(self, output):
+        """
+        Move hosts that timed out AND fail check_connectivity() from reachable_hosts to unreachable_hosts.
+
+        output: the per-host results returned by self.client.run_command(). Only items whose .exception is
+        a pssh Timeout are candidates; other exceptions never trigger pruning, so hosts are not removed for
+        failures that may succeed on the next try. When at least one host is pruned, self.client is rebuilt
+        over the remaining reachable_hosts (with pkey or password, matching __init__).
+        """
+        initial_unreachable_len = len(self.unreachable_hosts)
+        # isinstance() is False for None, so hosts that reported no exception are skipped as well.
+        failed_hosts = [item.host for item in output if isinstance(item.exception, Timeout)]
+        for host in failed_hosts:
+            # Skip hosts that are no longer listed (e.g. reported twice) so remove() cannot raise ValueError.
+            if host in self.reachable_hosts and not self.check_connectivity(host):
+                print(f"Host {host} is unreachable, pruning from reachable hosts list.")
+                self.unreachable_hosts.append(host)
+                self.reachable_hosts.remove(host)
+        if len(self.unreachable_hosts) > initial_unreachable_len:
+            # Recreate client with filtered reachable_hosts, only if hosts were actually pruned
+            if self.password is None:
+                self.client = ParallelSSHClient(self.reachable_hosts, user=self.user, pkey=self.pkey, keepalive_seconds=30)
+            else:
+                self.client = ParallelSSHClient(self.reachable_hosts, user=self.user, password=self.password, keepalive_seconds=30)
+
+
+    def inform_unreachability(self, cmd_output):
+        """
+        Append the "ABORT: Host Unreachable Error" marker to cmd_output[host] for every host in unreachable_hosts.
+        Hosts pruned by an earlier call are no longer part of the client's output, so a missing key starts from ''.
+        """
+        for host in self.unreachable_hosts:
+            cmd_output[host] = cmd_output.get(host, '') + "\nABORT: Host Unreachable Error"
 
 
     def exec(self, cmd, timeout=None ):
@@ -53,9 +104,9 @@ def exec(self, cmd, timeout=None ):
         cmd_output = {}
         print(f'cmd = {cmd}')
         if timeout is None:
-            output = self.client.run_command(cmd )
+            output = self.client.run_command(cmd, stop_on_errors=self.stop_on_errors )
         else:
-            output = self.client.run_command(cmd, read_timeout=timeout )
+            output = self.client.run_command(cmd, read_timeout=timeout, stop_on_errors=self.stop_on_errors )
         for item in output:
             print('#----------------------------------------------------------#')
             print(f'Host == {item.host} ==')
@@ -71,8 +122,16 @@ def exec(self, cmd, timeout=None ):
                 print(line)
                 cmd_out_str = cmd_out_str + line.replace( '\t', ' ')
                 cmd_out_str = cmd_out_str + '\n'
+            if item.exception:
+                exc_str = str(item.exception).replace('\t', ' ')
+                print(exc_str)
+                cmd_out_str += exc_str + '\n'
             cmd_output[item.host] = cmd_out_str
 
+        if not self.stop_on_errors:
+            self.prune_unreachable_hosts(output)
+            self.inform_unreachability(cmd_output)
+
         return cmd_output
 
 
@@ -85,9 +144,9 @@ def exec_cmd_list(self, cmd_list, timeout=None ):
         cmd_output = {}
         print(cmd_list)
         if timeout is None:
-            output = self.client.run_command( '%s', host_args=cmd_list )
+            output = self.client.run_command( '%s', host_args=cmd_list, stop_on_errors=self.stop_on_errors )
         else:
-            output = self.client.run_command( '%s', host_args=cmd_list, read_timeout=timeout )
+            output = self.client.run_command( '%s', host_args=cmd_list, read_timeout=timeout, stop_on_errors=self.stop_on_errors )
         i = 0
         for item in output:
             print('#----------------------------------------------------------#')
@@ -105,9 +164,17 @@ def exec_cmd_list(self, cmd_list, timeout=None ):
                 print(line)
                 cmd_out_str = cmd_out_str + line.replace( '\t', ' ')
                 cmd_out_str = cmd_out_str + '\n'
+            if item.exception:
+                exc_str = str(item.exception).replace('\t', ' ')
+                print(exc_str)
+                cmd_out_str += exc_str + '\n'
             i=i+1
             cmd_output[item.host] = cmd_out_str
 
+        if not self.stop_on_errors:
+            self.prune_unreachable_hosts(output)
+            self.inform_unreachability(cmd_output)
+
         return cmd_output
 
 
@@ -126,7 +193,7 @@ def scp_file(self, local_file, remote_file, recurse=False ):
 
     def reboot_connections(self ):
         print('Rebooting Connections')
-        self.client.run_command( 'reboot -f' )
+        self.client.run_command( 'reboot -f', stop_on_errors=self.stop_on_errors )
 
     def destroy_clients(self ):
         print('Destroying Current phdl connections ..')
diff --git a/lib/unittests/__init__.py b/lib/unittests/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/lib/unittests/test_html_lib.py b/lib/unittests/test_html_lib.py
new file mode 100644
index 00000000..fffaf6c1
--- /dev/null
+++ b/lib/unittests/test_html_lib.py
@@ -0,0 +1,175 @@
+import unittest
+import tempfile
+import os
+
+# Import the module under test
+import lib.html_lib as html_lib
+
+# Pins the exact strings html_lib.normalize_bytes returns, with and without si=True / precision.
+class TestNormalizeBytes(unittest.TestCase):
+
+    def test_bytes_only(self):
+        self.assertEqual(html_lib.normalize_bytes(932), "932 B")
+
+    def test_kilobytes_binary(self):
+        self.assertEqual(html_lib.normalize_bytes(2048), "2 KB")
+
+    def test_kilobytes_decimal(self):
+        self.assertEqual(html_lib.normalize_bytes(2000, si=True), "2 kB")
+
+    def test_megabytes(self):
+        self.assertEqual(html_lib.normalize_bytes(5 * 1024 * 1024), "5 MB")
+
+    def test_gigabytes(self):
+        self.assertEqual(html_lib.normalize_bytes(3 * 1024**3), "3 GB")
+
+    def test_negative_bytes(self):
+        self.assertEqual(html_lib.normalize_bytes(-1024), "-1 KB")
+
+    def test_precision(self):
+        self.assertEqual(html_lib.normalize_bytes(1536, precision=1), "1.5 KB")
+
+    def test_type_error(self):
+        with self.assertRaises(TypeError):
+            html_lib.normalize_bytes("not a number")
+
+
+class TestBuildHtmlMemUtilizationTable(unittest.TestCase):
+
+    def setUp(self):
+        # delete=False keeps the file on disk after close(); tearDown removes it explicitly.
+        self.tmp_file = tempfile.NamedTemporaryFile(delete=False, mode='w+', encoding='utf-8')
+        self.filename = self.tmp_file.name
+
+    def tearDown(self):
+        self.tmp_file.close()
+        os.remove(self.filename)
+
+    def test_single_node_valid_input(self):
+        use_dict = {
+            "node1": {
+                **{f"card{i}": {
+                    "GPU Memory Allocated (VRAM%)": f"{i*10}%",
+                    "GPU Memory Read/Write Activity (%)": f"{i*5}%",
+                    "Memory Activity": f"{i*3}%",
+                    "Avg. Memory Bandwidth": f"{i*2} GB/s"
+                } for i in range(8)}
+            }
+        }
+
+        amd_dict = {
+            "node1": [
+                {
+                    "mem_usage": {
+                        "total_vram": {"value": "16384"},
+                        "used_vram": {"value": "8192"},
+                        "free_vram": {"value": "8192"}
+                    }
+                } for _ in range(8)
+            ]
+        }
+
+        html_lib.build_html_mem_utilization_table(self.filename, use_dict, amd_dict)
+        with open(self.filename, 'r', encoding='utf-8') as f:
+            content = f.read()
+        self.assertIn("GPU Memory Utilization", content)
+        self.assertIn("G0 Tot VRAM MB", content)
+        self.assertIn("node1", content)
+        self.assertIn("8192", content)
+        self.assertIn("10%", content)
+
+    def test_multiple_nodes(self):
+        use_dict = {
+            f"node{i}": {
+                **{f"card{j}": {
+                    "GPU Memory Allocated (VRAM%)": f"{j*10}%",
+                    "GPU Memory Read/Write Activity (%)": f"{j*5}%",
+                    "Memory Activity": f"{j*3}%",
+                    "Avg. Memory Bandwidth": f"{j*2} GB/s"
+                } for j in range(8)}
+            } for i in range(2)
+        }
+
+        amd_dict = {
+            f"node{i}": [
+                {
+                    "mem_usage": {
+                        "total_vram": {"value": "16384"},
+                        "used_vram": {"value": "8192"},
+                        "free_vram": {"value": "8192"}
+                    }
+                } for _ in range(8)
+            ] for i in range(2)
+        }
+
+        html_lib.build_html_mem_utilization_table(self.filename, use_dict, amd_dict)
+        with open(self.filename, 'r', encoding='utf-8') as f:
+            content = f.read()
+        self.assertIn("node0", content)
+        self.assertIn("node1", content)
+
+    def test_rocm7_style_gpu_data(self):
+        # Here the per-GPU list is nested under a "gpu_data" key instead of being the node's value directly.
+        use_dict = {
+            "node1": {
+                **{f"card{i}": {
+                    "GPU Memory Allocated (VRAM%)": f"{i*10}%",
+                    "GPU Memory Read/Write Activity (%)": f"{i*5}%",
+                    "Memory Activity": f"{i*3}%",
+                    "Avg. Memory Bandwidth": f"{i*2} GB/s"
+                } for i in range(8)}
+            }
+        }
+
+        amd_dict = {
+            "node1": {
+                "gpu_data": [
+                    {
+                        "mem_usage": {
+                            "total_vram": {"value": "16384"},
+                            "used_vram": {"value": "8192"},
+                            "free_vram": {"value": "8192"}
+                        }
+                    } for _ in range(8)
+                ]
+            }
+        }
+
+        html_lib.build_html_mem_utilization_table(self.filename, use_dict, amd_dict)
+        with open(self.filename, 'r', encoding='utf-8') as f:
+            content = f.read()
+        self.assertIn("GPU Memory Utilization", content)
+        self.assertIn("G0 Tot VRAM MB", content)
+        self.assertIn("node1", content)
+
+    def test_missing_gpu_key_raises_keyerror(self):
+        use_dict = {
+            "node1": {
+                "card0": {
+                    "GPU Memory Allocated (VRAM%)": "10%",
+                    "GPU Memory Read/Write Activity (%)": "20%",
+                    "Memory Activity": "30%",
+                    "Avg. Memory Bandwidth": "40 GB/s"
+                }
+                # Missing card1 to card7
+            }
+        }
+
+        amd_dict = {
+            "node1": [
+                {
+                    "mem_usage": {
+                        "total_vram": {"value": "16384"},
+                        "used_vram": {"value": "8192"},
+                        "free_vram": {"value": "8192"}
+                    }
+                } for _ in range(8)
+            ]
+        }
+
+        with self.assertRaises(KeyError):
+            html_lib.build_html_mem_utilization_table(self.filename, use_dict, amd_dict)
+
+
+if __name__ == '__main__':
+    unittest.main()
\ No newline at end of file
diff --git a/lib/unittests/test_parallel_ssh_lib.py b/lib/unittests/test_parallel_ssh_lib.py
new file mode 100644
index 00000000..42140cc2
--- /dev/null
+++ b/lib/unittests/test_parallel_ssh_lib.py
@@ -0,0 +1,507 @@
+import unittest
+from unittest.mock import patch, MagicMock
+from lib.parallel_ssh_lib import Pssh
+
+class TestPsshExec(unittest.TestCase):
+
+    # The patch is only active while setUp runs; self.pssh keeps the MagicMock client created here.
+    @patch('lib.parallel_ssh_lib.ParallelSSHClient')
+    def setUp(self, mock_pssh_client):
+        self.mock_client = MagicMock()
+        mock_pssh_client.return_value = self.mock_client
+        self.mock_pssh_client = mock_pssh_client
+        self.host_list = ['host1', 'host2']
+        self.pssh = Pssh('log', self.host_list, user='user', password='pass')
+
+    def test_exec_successful(self):
+        # Test: Execute command successfully on all hosts
+        mock_output1 = MagicMock()
+        mock_output1.host = 'host1'
+        mock_output1.stdout = ['output1 line1', 'output1 line2']
+        mock_output1.stderr = []
+        mock_output1.exception = None
+
+        mock_output2 = MagicMock()
+        mock_output2.host = 'host2'
+        mock_output2.stdout = ['output2 line1']
+        mock_output2.stderr = []
+        mock_output2.exception = None
+
+        self.mock_client.run_command.return_value = [mock_output1, mock_output2]
+
+        result = self.pssh.exec('echo hello')
+
+        self.mock_client.run_command.assert_called_once_with('echo hello', stop_on_errors=True)
+        self.assertIn('host1', result)
+        self.assertIn('host2', result)
+        self.assertIn('output1 line1', result['host1'])
+        self.assertIn('output2 line1', result['host2'])
+
+    def test_exec_with_exception_stop_on_errors_true(self):
+        # Test: Handle exceptions with stop_on_errors=True (default)
+        # Exception should be raised, and no result returned (no partial results)
+        from pssh.exceptions import Timeout
+        self.mock_client.run_command.side_effect = Timeout('Connection failed')
+
+        # With stop_on_errors=True, run_command raises on exception, no result returned
+        with self.assertRaises(Timeout) as cm:
+            result = self.pssh.exec('echo hello')  # This should raise, so result is not assigned
+
+        self.assertIn('Connection failed', str(cm.exception))
+        # Since exception was raised, result was not returned
+        self.assertNotIn('result', locals())
+
+    @patch.object(Pssh, 'check_connectivity')
+    def test_exec_with_exception_stop_on_errors_false(self, mock_check_connectivity):
+        # Test: Execute command with timeout and stop_on_errors=False
+        # Exception should not be raised; it is recorded in the output of the failed host, success for others
+        self.pssh.stop_on_errors = False
+        from pssh.exceptions import Timeout
+        mock_output1 = MagicMock()
+        mock_output1.host = 'host1'
+        mock_output1.stdout = ['success output']
+        mock_output1.stderr = []
+        mock_output1.exception = None
+
+        mock_output2 = MagicMock()
+        mock_output2.host = 'host2'
+        mock_output2.stdout = []
+        mock_output2.stderr = []
+        mock_output2.exception = Timeout('Command timed out')
+
+        self.mock_client.run_command.return_value = [mock_output1, mock_output2]
+        mock_check_connectivity.return_value = True  # No pruning
+
+        result = self.pssh.exec('echo hello', timeout=10)
+
+        self.mock_client.run_command.assert_called_once_with('echo hello', read_timeout=10, stop_on_errors=False)
+        self.assertIn('host1', result)
+        self.assertIn('host2', result)
+        self.assertIn('success output', result['host1'])
+        self.assertIn('Command timed out', result['host2'])
+
+    @patch('lib.parallel_ssh_lib.ParallelSSHClient')
+    @patch.object(Pssh, 'check_connectivity')
+    def test_exec_with_pruning_unreachable_host(self, mock_check_connectivity, mock_pssh_client):
+        # Test: With stop_on_errors=False, timeout on host2, and check_connectivity fails for host2, prune it
+        # setUp's Pssh is replaced so that the ParallelSSHClient constructor calls can be counted below.
+        self.mock_client = MagicMock()
+        mock_pssh_client.return_value = self.mock_client
+        self.host_list = ['host1', 'host2']
+        self.pssh = Pssh('log', self.host_list, user='user', password='pass')
+        self.pssh.stop_on_errors = False
+        self.pssh.check_connectivity = mock_check_connectivity
+        from pssh.exceptions import Timeout
+        mock_output1 = MagicMock()
+        mock_output1.host = 'host1'
+        mock_output1.stdout = ['success output']
+        mock_output1.stderr = []
+        mock_output1.exception = None
+
+        mock_output2 = MagicMock()
+        mock_output2.host = 'host2'
+        mock_output2.stdout = []
+        mock_output2.stderr = []
+        mock_output2.exception = Timeout('Command timed out')
+
+        self.mock_client.run_command.return_value = [mock_output1, mock_output2]
+        mock_check_connectivity.return_value = False  # Simulate unreachable
+
+        result = self.pssh.exec('echo hello', timeout=10)
+
+        self.assertEqual(self.pssh.reachable_hosts, ['host1'])
+        self.assertEqual(self.pssh.unreachable_hosts, ['host2'])
+        self.assertIn('host1', result)
+        self.assertIn('host2', result)
+        self.assertIn('success output', result['host1'])
+        self.assertEqual(result['host2'], 'Command timed out\n\nABORT: Host Unreachable Error')
+        # Client should be recreated once (init + prune)
+        self.assertEqual(mock_pssh_client.call_count, 2)
+
+    @patch('lib.parallel_ssh_lib.ParallelSSHClient')
+    @patch.object(Pssh, 'check_connectivity')
+    def test_exec_no_pruning_when_reachable(self, mock_check_connectivity, mock_pssh_client):
+        # Test: With stop_on_errors=False, timeout on host2, but check_connectivity succeeds, no pruning
+        self.mock_client = MagicMock()
+        mock_pssh_client.return_value = self.mock_client
+        self.host_list = ['host1', 'host2']
+        self.pssh = Pssh('log', self.host_list, user='user', password='pass')
+        self.pssh.stop_on_errors = False
+        self.pssh.check_connectivity = mock_check_connectivity
+        from pssh.exceptions import Timeout
+        mock_output1 = MagicMock()
+        mock_output1.host = 'host1'
+        mock_output1.stdout = ['success output']
+        mock_output1.stderr = []
+        mock_output1.exception = None
+
+        mock_output2 = MagicMock()
+        mock_output2.host = 'host2'
+        mock_output2.stdout = []
+        mock_output2.stderr = []
+        mock_output2.exception = Timeout('Command timed out')
+
+        self.mock_client.run_command.return_value = [mock_output1, mock_output2]
+        mock_check_connectivity.return_value = True  # Always reachable
+
+        result = self.pssh.exec('echo hello', timeout=10)
+
+        self.assertEqual(self.pssh.reachable_hosts, ['host1', 'host2'])  # No change
+        self.assertEqual(self.pssh.unreachable_hosts, [])
+        self.assertIn('host1', result)
+        self.assertIn('host2', result)
+        self.assertIn('success output', result['host1'])
+        self.assertIn('Command timed out', result['host2'])  # Original exception
+        # Client not recreated
+        self.assertEqual(mock_pssh_client.call_count, 1)
+
+    @patch('lib.parallel_ssh_lib.ParallelSSHClient')
+    @patch.object(Pssh, 'check_connectivity')
+    def test_exec_pruning_with_multiple_unreachable_hosts(self, mock_check_connectivity, mock_pssh_client):
+        # Test: With stop_on_errors=False, multiple hosts (host2, host3) timeout and are unreachable, prune all
+        self.mock_client = MagicMock()
+        mock_pssh_client.return_value = self.mock_client
+        self.host_list = ['host1', 'host2', 'host3']
+        self.pssh = Pssh('log', self.host_list, user='user', password='pass')
+        self.pssh.stop_on_errors = False
+        self.pssh.check_connectivity = mock_check_connectivity
+        from pssh.exceptions import Timeout
+        mock_output1 = MagicMock()
+        mock_output1.host = 'host1'
+        mock_output1.stdout = ['success output']
+        mock_output1.stderr = []
+        mock_output1.exception = None
+
+        mock_output2 = MagicMock()
+        mock_output2.host = 'host2'
+        mock_output2.stdout = []
+        mock_output2.stderr = []
+        mock_output2.exception = Timeout('Command timed out')
+
+        mock_output3 = MagicMock()
+        mock_output3.host = 'host3'
+        mock_output3.stdout = []
+        mock_output3.stderr = []
+        mock_output3.exception = Timeout('Command timed out')
+
+        self.mock_client.run_command.return_value = [mock_output1, mock_output2, mock_output3]
+        mock_check_connectivity.return_value = False  # Simulate all unreachable
+
+        result = self.pssh.exec('echo hello', timeout=10)
+
+        self.assertEqual(self.pssh.reachable_hosts, ['host1'])
+        self.assertEqual(sorted(self.pssh.unreachable_hosts), ['host2', 'host3'])
+        self.assertIn('host1', result)
+        self.assertIn('host2', result)
+        self.assertIn('host3', result)
+        self.assertIn('success output', result['host1'])
+        self.assertEqual(result['host2'], 'Command timed out\n\nABORT: Host Unreachable Error')
+        self.assertEqual(result['host3'], 'Command timed out\n\nABORT: Host Unreachable Error')
+        # Client should be recreated once (init + prune)
+        self.assertEqual(mock_pssh_client.call_count, 2)
+
+    @patch('lib.parallel_ssh_lib.ParallelSSHClient')
+    def test_exec_no_pruning_on_non_timeout_exception(self, mock_pssh_client):
+        # Test: With stop_on_errors=False, non-Timeout exception on host2, no pruning occurs
+        # Non-Timeout exceptions are treated as transient failures (e.g., authentication or network glitches),
+        # not indicative of permanent unreachability, so we don't attempt connectivity checks or pruning
+        self.mock_client = MagicMock()
+        mock_pssh_client.return_value = self.mock_client
+        self.host_list = ['host1', 'host2']
+        self.pssh = Pssh('log', self.host_list, user='user', password='pass')
+        self.pssh.stop_on_errors = False
+        from pssh.exceptions import ConnectionError
+        mock_output1 = MagicMock()
+        mock_output1.host = 'host1'
+        mock_output1.stdout = ['success output']
+        mock_output1.stderr = []
+        mock_output1.exception = None
+
+        mock_output2 = MagicMock()
+        mock_output2.host = 'host2'
+        mock_output2.stdout = []
+        mock_output2.stderr = []
+        mock_output2.exception = ConnectionError('Connection failed')
+
+        self.mock_client.run_command.return_value = [mock_output1, mock_output2]
+
+        result = self.pssh.exec('echo hello', timeout=10)
+
+        self.assertEqual(self.pssh.reachable_hosts, ['host1', 'host2'])  # No pruning
+        self.assertEqual(self.pssh.unreachable_hosts, [])
+        self.assertIn('host1', result)
+        self.assertIn('host2', result)
+        self.assertIn('success output', result['host1'])
+        self.assertIn('Connection failed', result['host2'])  # Original exception
+        # Client not recreated
+        self.assertEqual(mock_pssh_client.call_count, 1)
+
+    @patch.object(Pssh, 'prune_unreachable_hosts')
+    @patch.object(Pssh, 'inform_unreachability')
+    def test_exec_no_pruning_when_stop_on_errors_true(self, mock_inform, mock_prune):
+        # Test: With stop_on_errors=True, no pruning even with timeout
+        # Since stop_on_errors=True, run_command raises immediately, so prune_unreachable_hosts and inform_unreachability are not invoked
+        from pssh.exceptions import Timeout
+        self.mock_client.run_command.side_effect = Timeout('Command timed out')
+
+        with self.assertRaises(Timeout):
+            self.pssh.exec('echo hello', timeout=10)
+
+        # Assert that pruning methods were not called
+        mock_prune.assert_not_called()
+        mock_inform.assert_not_called()
+
+
+class TestPsshExecCmdList(unittest.TestCase):
+
+    
class TestPsshExecCmdList(unittest.TestCase):
    # Tests for Pssh.exec_cmd_list (one command per host): success, error
    # propagation, and pruning of hosts that time out and fail the
    # connectivity check. Comments are used instead of docstrings so the
    # verbose test-runner output keeps showing the plain test names.

    @staticmethod
    def _host_output(host, stdout=None, exception=None):
        # Stand-in for one per-host entry of run_command's return value.
        entry = MagicMock()
        entry.host = host
        entry.stdout = [] if stdout is None else stdout
        entry.stderr = []
        entry.exception = exception
        return entry

    def _rebuild_pssh(self, mock_pssh_client, hosts):
        # Re-create the client mock and the Pssh object under the calling test's
        # own ParallelSSHClient patch, so constructor calls can be counted on
        # mock_pssh_client. stop_on_errors is switched off for these scenarios.
        self.mock_client = MagicMock()
        mock_pssh_client.return_value = self.mock_client
        self.host_list = hosts
        self.pssh = Pssh('log', self.host_list, user='user', password='pass')
        self.pssh.stop_on_errors = False

    @patch('lib.parallel_ssh_lib.ParallelSSHClient')
    def setUp(self, mock_pssh_client):
        # Default fixture: two hosts and a mocked SSH client. The patch is only
        # active while setUp runs; self.pssh keeps the mocked client afterwards.
        client = MagicMock()
        mock_pssh_client.return_value = client
        self.mock_client = client
        self.mock_pssh_client = mock_pssh_client
        self.host_list = ['host1', 'host2']
        self.pssh = Pssh('log', self.host_list, user='user', password='pass')

    def test_exec_cmd_list_successful(self):
        # Test: Execute different commands on different hosts successfully
        cmd_list = ['echo host1', 'echo host2']
        self.mock_client.run_command.return_value = [
            self._host_output('host1', ['host1']),
            self._host_output('host2', ['host2']),
        ]

        result = self.pssh.exec_cmd_list(cmd_list)

        self.mock_client.run_command.assert_called_once_with(
            '%s', host_args=cmd_list, stop_on_errors=True
        )
        for host in ('host1', 'host2'):
            self.assertIn(host, result)
        for host in ('host1', 'host2'):
            self.assertIn(host, result[host])

    @patch.object(Pssh, 'check_connectivity')
    def test_exec_cmd_list_with_exception_stop_on_errors_false(self, mock_check_connectivity):
        # Test: with stop_on_errors=False a per-host exception is not raised;
        # its text is placed in the output for the failed host while the other
        # host reports its normal output.
        from pssh.exceptions import Timeout

        self.pssh.stop_on_errors = False
        cmd_list = ['echo success', 'echo fail']
        self.mock_client.run_command.return_value = [
            self._host_output('host1', ['success']),
            self._host_output('host2', exception=Timeout('Command timed out')),
        ]
        mock_check_connectivity.return_value = True  # Simulate reachable, no pruning

        result = self.pssh.exec_cmd_list(cmd_list, timeout=10)

        self.mock_client.run_command.assert_called_once_with(
            '%s', host_args=cmd_list, read_timeout=10, stop_on_errors=False
        )
        for host in ('host1', 'host2'):
            self.assertIn(host, result)
        self.assertIn('success', result['host1'])
        self.assertIn('Command timed out', result['host2'])

    def test_exec_cmd_list_with_exception_stop_on_errors_true(self):
        # Test: with the default stop_on_errors=True the exception is raised
        # and no (partial) result is returned.
        from pssh.exceptions import Timeout

        cmd_list = ['echo test']
        self.mock_client.run_command.side_effect = Timeout('Command timed out')

        with self.assertRaises(Timeout) as raised:
            result = self.pssh.exec_cmd_list(cmd_list, timeout=5)

        self.assertIn('Command timed out', str(raised.exception))
        # 'result' was never bound because the call raised.
        self.assertNotIn('result', locals())

    @patch('lib.parallel_ssh_lib.ParallelSSHClient')
    @patch.object(Pssh, 'check_connectivity')
    def test_exec_cmd_list_no_pruning_when_reachable(self, mock_check_connectivity, mock_pssh_client):
        # Test: stop_on_errors=False, host2 times out, but check_connectivity
        # reports it reachable -> nothing is pruned.
        from pssh.exceptions import Timeout

        self._rebuild_pssh(mock_pssh_client, ['host1', 'host2'])
        self.pssh.check_connectivity = mock_check_connectivity
        cmd_list = ['echo success', 'echo fail']
        self.mock_client.run_command.return_value = [
            self._host_output('host1', ['success']),
            self._host_output('host2', exception=Timeout('Command timed out')),
        ]
        mock_check_connectivity.return_value = True  # Always reachable

        result = self.pssh.exec_cmd_list(cmd_list, timeout=10)

        self.assertEqual(self.pssh.reachable_hosts, ['host1', 'host2'])  # No change
        self.assertEqual(self.pssh.unreachable_hosts, [])
        for host in ('host1', 'host2'):
            self.assertIn(host, result)
        self.assertIn('success', result['host1'])
        self.assertIn('Command timed out', result['host2'])  # Original exception
        # Client not recreated
        self.assertEqual(mock_pssh_client.call_count, 1)

    @patch('lib.parallel_ssh_lib.ParallelSSHClient')
    @patch.object(Pssh, 'check_connectivity')
    def test_exec_cmd_list_with_pruning(self, mock_check_connectivity, mock_pssh_client):
        # Test: host2 times out and fails the connectivity check -> it is pruned
        # and the client is constructed a second time.
        from pssh.exceptions import Timeout

        self._rebuild_pssh(mock_pssh_client, ['host1', 'host2'])
        self.pssh.check_connectivity = mock_check_connectivity
        cmd_list = ['echo success', 'echo fail']
        self.mock_client.run_command.return_value = [
            self._host_output('host1', ['success']),
            self._host_output('host2', exception=Timeout('Command timed out')),
        ]
        mock_check_connectivity.return_value = False

        result = self.pssh.exec_cmd_list(cmd_list, timeout=10)

        self.assertEqual(self.pssh.reachable_hosts, ['host1'])
        self.assertEqual(self.pssh.unreachable_hosts, ['host2'])
        for host in ('host1', 'host2'):
            self.assertIn(host, result)
        self.assertIn('success', result['host1'])
        self.assertEqual(result['host2'], 'Command timed out\n\nABORT: Host Unreachable Error')
        self.assertEqual(mock_pssh_client.call_count, 2)

    @patch('lib.parallel_ssh_lib.ParallelSSHClient')
    @patch.object(Pssh, 'check_connectivity')
    def test_exec_cmd_list_pruning_with_multiple_unreachable_hosts(self, mock_check_connectivity, mock_pssh_client):
        # Test: two of three hosts time out and fail the connectivity check ->
        # both are pruned, with a single client re-creation.
        from pssh.exceptions import Timeout

        self._rebuild_pssh(mock_pssh_client, ['host1', 'host2', 'host3'])
        self.pssh.check_connectivity = mock_check_connectivity
        cmd_list = ['echo success', 'echo fail1', 'echo fail2']
        self.mock_client.run_command.return_value = [
            self._host_output('host1', ['success']),
            self._host_output('host2', exception=Timeout('Command timed out')),
            self._host_output('host3', exception=Timeout('Command timed out')),
        ]
        mock_check_connectivity.return_value = False

        result = self.pssh.exec_cmd_list(cmd_list, timeout=10)

        self.assertEqual(self.pssh.reachable_hosts, ['host1'])
        self.assertEqual(sorted(self.pssh.unreachable_hosts), ['host2', 'host3'])
        for host in ('host1', 'host2', 'host3'):
            self.assertIn(host, result)
        self.assertIn('success', result['host1'])
        self.assertEqual(result['host2'], 'Command timed out\n\nABORT: Host Unreachable Error')
        self.assertEqual(result['host3'], 'Command timed out\n\nABORT: Host Unreachable Error')
        self.assertEqual(mock_pssh_client.call_count, 2)

    @patch('lib.parallel_ssh_lib.ParallelSSHClient')
    def test_exec_cmd_list_no_pruning_on_non_timeout_exception(self, mock_pssh_client):
        # Test: a non-Timeout exception on host2 triggers neither connectivity
        # checks nor pruning; the original error text is reported.
        from pssh.exceptions import ConnectionError

        self._rebuild_pssh(mock_pssh_client, ['host1', 'host2'])
        cmd_list = ['echo success', 'echo fail']
        self.mock_client.run_command.return_value = [
            self._host_output('host1', ['success']),
            self._host_output('host2', exception=ConnectionError('Connection failed')),
        ]

        result = self.pssh.exec_cmd_list(cmd_list, timeout=10)

        self.assertEqual(self.pssh.reachable_hosts, ['host1', 'host2'])  # No pruning
        self.assertEqual(self.pssh.unreachable_hosts, [])
        for host in ('host1', 'host2'):
            self.assertIn(host, result)
        self.assertIn('success', result['host1'])
        self.assertIn('Connection failed', result['host2'])  # Original exception
        # Client not recreated
        self.assertEqual(mock_pssh_client.call_count, 1)

    @patch.object(Pssh, 'prune_unreachable_hosts')
    @patch.object(Pssh, 'inform_unreachability')
    def test_exec_cmd_list_no_pruning_when_stop_on_errors_true(self, mock_inform, mock_prune):
        # Test: run_command raises Timeout (stop_on_errors=True path), so the
        # exception propagates and the pruning helpers are never invoked.
        from pssh.exceptions import Timeout

        cmd_list = ['echo test']
        self.mock_client.run_command.side_effect = Timeout('Command timed out')

        self.assertRaises(Timeout, self.pssh.exec_cmd_list, cmd_list, timeout=5)

        # Assert that pruning methods were not called
        for untouched in (mock_prune, mock_inform):
            untouched.assert_not_called()
class TestVerifyGpuPcieBusWidth(unittest.TestCase):
    # Drives verify_lib.verify_gpu_pcie_bus_width with a patched bus dictionary
    # and canned per-node link-status strings returned by the handle mock.

    @staticmethod
    def _handle_returning(link_status_by_node):
        # Parallel-ssh handle stand-in whose exec_cmd_list yields the given dict.
        handle = MagicMock()
        handle.exec_cmd_list.return_value = link_status_by_node
        return handle

    @patch('lib.verify_lib.get_gpu_pcie_bus_dict')
    @patch('lib.verify_lib.fail_test')
    def test_valid_bus_width(self, mock_fail_test, mock_get_bus_dict):
        # Two nodes, two cards each, every link reported at 32GT/s x16:
        # no per-node errors are returned and fail_test is never called.
        bus_ids = {
            'node1': ('0000:01:00.0', '0000:02:00.0'),
            'node2': ('0000:03:00.0', '0000:04:00.0'),
        }
        mock_get_bus_dict.return_value = {
            node: {'card%d' % idx: {'PCI Bus': bus} for idx, bus in enumerate(buses)}
            for node, buses in bus_ids.items()
        }
        phdl = self._handle_returning(
            dict.fromkeys(bus_ids, 'LnkSta: Speed 32GT/s, Width x16')
        )

        outcome = verify_lib.verify_gpu_pcie_bus_width(phdl, expected_cards=2)

        self.assertEqual(outcome, {'node1': [], 'node2': []})
        mock_fail_test.assert_not_called()

    @patch('lib.verify_lib.get_gpu_pcie_bus_dict')
    @patch('lib.verify_lib.fail_test')
    def test_invalid_bus_speed(self, mock_fail_test, mock_get_bus_dict):
        # A single card whose link reports 16GT/s must make the check call fail_test.
        mock_get_bus_dict.return_value = {
            'node1': {'card0': {'PCI Bus': '0000:01:00.0'}},
        }
        phdl = self._handle_returning({'node1': 'LnkSta: Speed 16GT/s, Width x16'})

        verify_lib.verify_gpu_pcie_bus_width(phdl, expected_cards=1)

        mock_fail_test.assert_called()


class TestVerifyGpuPcieErrors(unittest.TestCase):
    # Drives verify_lib.verify_gpu_pcie_errors with a patched metrics dictionary
    # holding the three PCIe counters for a single card on a single node.

    @staticmethod
    def _single_card_metrics(l0_to_recov, nak_sent, nak_rcvd):
        # Metrics dict shaped like the one the tests feed into the patched getter.
        return {
            'node1': {
                'card0': {
                    'pcie_l0_to_recov_count_acc (Count)': l0_to_recov,
                    'pcie_nak_sent_count_acc (Count)': nak_sent,
                    'pcie_nak_rcvd_count_acc (Count)': nak_rcvd,
                }
            }
        }

    @patch('lib.verify_lib.get_gpu_metrics_dict')
    @patch('lib.verify_lib.fail_test')
    def test_valid_error_metrics(self, mock_fail_test, mock_get_metrics):
        # Counters of 10/20/30 produce no findings and no fail_test call.
        mock_get_metrics.return_value = self._single_card_metrics('10', '20', '30')

        outcome = verify_lib.verify_gpu_pcie_errors(MagicMock())

        self.assertEqual(outcome, {'node1': []})
        mock_fail_test.assert_not_called()

    @patch('lib.verify_lib.get_gpu_metrics_dict')
    @patch('lib.verify_lib.fail_test')
    def test_threshold_exceeded(self, mock_fail_test, mock_get_metrics):
        # Counters of 101/150/200 yield three findings for node1 and a fail_test call.
        mock_get_metrics.return_value = self._single_card_metrics('101', '150', '200')

        outcome = verify_lib.verify_gpu_pcie_errors(MagicMock())

        self.assertEqual(len(outcome['node1']), 3)
        mock_fail_test.assert_called()
mock_get_metrics.return_value = { + 'node1': { + 'card0': { + 'pcie_l0_to_recov_count_acc (Count)': '10', + 'pcie_nak_sent_count_acc (Count)': '20', + 'pcie_nak_rcvd_count_acc (Count)': '30' + } + } + } + + phdl = MagicMock() + result = verify_lib.verify_gpu_pcie_errors(phdl) + self.assertEqual(result, {'node1': []}) + mock_fail_test.assert_not_called() + + @patch('lib.verify_lib.get_gpu_metrics_dict') + @patch('lib.verify_lib.fail_test') + def test_threshold_exceeded(self, mock_fail_test, mock_get_metrics): + mock_get_metrics.return_value = { + 'node1': { + 'card0': { + 'pcie_l0_to_recov_count_acc (Count)': '101', + 'pcie_nak_sent_count_acc (Count)': '150', + 'pcie_nak_rcvd_count_acc (Count)': '200' + } + } + } + + phdl = MagicMock() + result = verify_lib.verify_gpu_pcie_errors(phdl) + self.assertEqual(len(result['node1']), 3) + mock_fail_test.assert_called() + + +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/run_all_unittests.py b/run_all_unittests.py new file mode 100644 index 00000000..7964218a --- /dev/null +++ b/run_all_unittests.py @@ -0,0 +1,23 @@ +# run_all_unittests.py +import sys +import os + +# Add lib directory to sys.path for absolute imports +lib_path = os.path.join(os.path.dirname(__file__), 'lib') +sys.path.insert(0, lib_path) + +import unittest + +def main(): + loader = unittest.TestLoader() + suite = unittest.TestSuite() + + # Add all unit test directories + for test_dir in ['lib/unittests']: + suite.addTests(loader.discover(start_dir=test_dir)) + + runner = unittest.TextTestRunner(verbosity=2) + runner.run(suite) + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/tests/health/agfhc_cvs.py b/tests/health/agfhc_cvs.py index 78c7abb4..f9795eeb 100644 --- a/tests/health/agfhc_cvs.py +++ b/tests/health/agfhc_cvs.py @@ -136,7 +136,7 @@ def phdl(cluster_dict): nhdl_dict = {} print(cluster_dict) node_list = list(cluster_dict['node_dict'].keys()) - phdl = Pssh( log, node_list, 
user=cluster_dict['username'], pkey=cluster_dict['priv_key_file'] ) + phdl = Pssh( log, node_list, user=cluster_dict['username'], pkey=cluster_dict['priv_key_file'], stop_on_errors=False ) return phdl diff --git a/tests/health/rvs_cvs.py b/tests/health/rvs_cvs.py index 6306bda0..c159c761 100644 --- a/tests/health/rvs_cvs.py +++ b/tests/health/rvs_cvs.py @@ -66,7 +66,7 @@ def config_dict(config_file, cluster_dict): def phdl(cluster_dict): print(cluster_dict) node_list = list(cluster_dict['node_dict'].keys()) - phdl = Pssh( log, node_list, user=cluster_dict['username'], pkey=cluster_dict['priv_key_file'] ) + phdl = Pssh( log, node_list, user=cluster_dict['username'], pkey=cluster_dict['priv_key_file'], stop_on_errors=False ) return phdl diff --git a/tests/inference/inferencemax/__pycache__/inference_max_benchmark.cpython-312-pytest-8.4.1.pyc b/tests/inference/inferencemax/__pycache__/inference_max_benchmark.cpython-312-pytest-8.4.1.pyc deleted file mode 100644 index bea64d7a..00000000 Binary files a/tests/inference/inferencemax/__pycache__/inference_max_benchmark.cpython-312-pytest-8.4.1.pyc and /dev/null differ