Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# Developer convenience targets: build a throw-away virtual environment,
# install the requirements into it and run the unit tests.

VENV_DIR = test_venv
PYTHON := $(shell command -v python3 || command -v python)
PIP = $(VENV_DIR)/bin/pip

# Every target below is a command name, not a file it produces.
.PHONY: help venv install test clean all clean_venv clean_pycache
all: venv install test

help:
	@echo "Available targets:"
	@echo " venv - Create virtual environment"
	@echo " install - Install requirements"
	@echo " test - Run the unit tests (run_all_unittests.py)"
	@echo " all - Run venv, install, and test"
	@echo " clean - Remove virtual environment and Python cache files"

# Always starts from a fresh environment: clean_venv removes any previous one.
venv: clean_venv
	@echo "Creating virtual environment..."
	$(PYTHON) -m venv $(VENV_DIR)

install: venv
	@echo "Installing requirements..."
	$(PIP) install -r requirements.txt

test: install
	@echo "Unit Testing cvs..."
	$(VENV_DIR)/bin/python run_all_unittests.py

# Refuses to delete the venv while it is the active one in the calling shell.
clean_venv:
	@echo "Removing virtual environment..."
	@if [ -n "$$VIRTUAL_ENV" ] && [ "$$VIRTUAL_ENV" = "$$(pwd)/$(VENV_DIR)" ]; then \
		echo "ERROR: You are currently in the venv. Please run 'deactivate' first."; \
		exit 1; \
	fi
	rm -rf $(VENV_DIR)

# Best effort: find errors are silenced and never fail the target.
clean_pycache:
	@echo "Removing Python cache files..."
	find . -type d -name "__pycache__" -exec rm -rf {} + 2>/dev/null || true
	find . -name "*.pyc" -delete 2>/dev/null || true
	find . -name "*.pyo" -delete 2>/dev/null || true

clean: clean_venv clean_pycache
92 changes: 92 additions & 0 deletions UNIT_TESTING_GUIDE.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@

# Unit Test Organization

This project separates **unit tests** and **cluster validation tests** to maintain clarity and modularity.

---

## 📁 Directory Structure

```
cvs/
├── lib/
│ ├── docker_lib.py
│ ├── utils_lib.py
│ ├── verify_lib.py
│ └── unittests/ # Unit tests for lib modules
│ ├── __init__.py
│ ├── test_verify_lib.py
│ └── test_html_lib.py
├── tests/ # cluster validation tests (Pytest)
│ ├── health
│ ├── ibperf
│ ├── rccl
├── conftest.py
├── pytest.ini
├── README.md
```

---

## ✅ Unit Tests (using `unittest`)

- Use Python's built-in `unittest` framework
- Each test file should be named `test_*.py`
- Each test class should inherit from `unittest.TestCase`

### 🔧 How to Run Unit Tests

From the project root (`cvs/`):

```bash
python -m unittest discover -s lib/unittests
```

This will discover and run all unit tests under `lib/unittests/`.

---

## 🛠️ Run All Unit Tests from Multiple Folders

If you have multiple unit test folders (e.g., `lib/unittests`, `utils/unittests`), use a single driver script at the project root, such as `run_all_unittests.py`, that runs the tests from every folder.

Run it with:

```bash
python run_all_unittests.py
test_missing_gpu_key_raises_keyerror (test_html_lib.TestBuildHtmlMemUtilizationTable.test_missing_gpu_key_raises_keyerror) ... Build HTML mem utilization table
ok
test_multiple_nodes (test_html_lib.TestBuildHtmlMemUtilizationTable.test_multiple_nodes) ... Build HTML mem utilization table
ok
test_rocm7_style_gpu_data (test_html_lib.TestBuildHtmlMemUtilizationTable.test_rocm7_style_gpu_data) ... Build HTML mem utilization table
ok
test_single_node_valid_input (test_html_lib.TestBuildHtmlMemUtilizationTable.test_single_node_valid_input) ... Build HTML mem utilization table
ok
test_bytes_only (test_html_lib.TestNormalizeBytes.test_bytes_only) ... ok
test_gigabytes (test_html_lib.TestNormalizeBytes.test_gigabytes) ... ok
test_kilobytes_binary (test_html_lib.TestNormalizeBytes.test_kilobytes_binary) ... ok
test_kilobytes_decimal (test_html_lib.TestNormalizeBytes.test_kilobytes_decimal) ... ok
test_megabytes (test_html_lib.TestNormalizeBytes.test_megabytes) ... ok
test_negative_bytes (test_html_lib.TestNormalizeBytes.test_negative_bytes) ... ok
test_precision (test_html_lib.TestNormalizeBytes.test_precision) ... ok
test_type_error (test_html_lib.TestNormalizeBytes.test_type_error) ... ok
test_invalid_bus_speed (test_verify_lib.TestVerifyGpuPcieBusWidth.test_invalid_bus_speed) ... ok
test_valid_bus_width (test_verify_lib.TestVerifyGpuPcieBusWidth.test_valid_bus_width) ... ok
test_threshold_exceeded (test_verify_lib.TestVerifyGpuPcieErrors.test_threshold_exceeded) ... ok
test_valid_error_metrics (test_verify_lib.TestVerifyGpuPcieErrors.test_valid_error_metrics) ... ok

----------------------------------------------------------------------
Ran 16 tests in 0.026s

OK
```

---

## 🧪 Tips for Organizing Tests

- Keep unit tests close to the code they test (e.g., `lib/unittests/`)
- Use `__init__.py` in all test directories to make them importable
- Use `unittest.mock` to isolate unit tests from external dependencies

---
Empty file added lib/__init__.py
Empty file.
95 changes: 81 additions & 14 deletions lib/parallel_ssh_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,14 @@

from __future__ import print_function
from pssh.clients import ParallelSSHClient
from pssh.exceptions import Timeout

import sys
import os
import re
import ast
import json
import time

# Following used only for scp of file
import paramiko
Expand All @@ -28,22 +30,71 @@ class Pssh():
mandatory args = user, password (or) 'private_key': load_private_key('my_key.pem')
"""

def __init__(self, log, host_list, user=None, password=None, pkey='id_rsa', host_key_check=False, stop_on_errors=True ):
    """
    Create the parallel SSH client for host_list.

    Args:
        log: logger object, stored on the instance.
        host_list: hosts to connect to; kept unchanged as self.host_list.
        user: SSH user name.
        password: SSH password; when None, key based auth with pkey is used.
        pkey: private key passed to the client when no password is given.
        host_key_check: stored on the instance (not used in this constructor).
        stop_on_errors: forwarded to every run_command call of this class.
    """
    self.log = log
    self.host_list = host_list
    # Copy, so pruning reachable_hosts later never alters the caller's list
    # (or self.host_list, which must keep the full original set of hosts).
    self.reachable_hosts = list(host_list)
    self.user = user
    self.pkey = pkey
    self.password = password
    self.host_key_check = host_key_check
    self.stop_on_errors = stop_on_errors
    self.unreachable_hosts = []

    if self.password is None:
        print(self.reachable_hosts)
        print(self.user)
        print(self.pkey)
        self.client = ParallelSSHClient( self.reachable_hosts, user=self.user, pkey=self.pkey, keepalive_seconds=30 )
    else:
        self.client = ParallelSSHClient( self.reachable_hosts, user=self.user, password=self.password, keepalive_seconds=30 )


def check_connectivity(self, host):
    """
    Probe a single host over SSH by running a trivial command on it.

    Returns True when the probe command could be issued, False when creating
    the client or running the command raised any exception.
    """
    try:
        # Key based auth is only offered when no password is configured.
        probe_key = None if self.password is not None else self.pkey
        probe_client = ParallelSSHClient(
            [host],
            user=self.user,
            pkey=probe_key,
            password=self.password,
            keepalive_seconds=30,
        )
        probe_client.run_command('echo 1', stop_on_errors=True, read_timeout=5)
    except Exception:
        return False
    return True

def prune_unreachable_hosts(self, output):
    """
    Prune hosts that timed out and also fail a fresh connectivity check.

    Only Timeout exceptions are considered. Other failures (authentication
    errors, SSH protocol issues) may succeed on the next try, so those hosts
    are kept. A timed-out host is pruned only after check_connectivity()
    confirms it cannot be reached.

    Args:
        output: iterable of per-host results from run_command; each item
            exposes .host and .exception.

    Side effects: pruned hosts are appended to self.unreachable_hosts, removed
    from self.reachable_hosts, and self.client is rebuilt for the remaining
    hosts. Nothing changes when no host is pruned.
    """
    # De-duplicate while keeping order: a host reported twice must be probed
    # and pruned only once.
    timed_out_hosts = list(dict.fromkeys(
        item.host for item in output
        if item.exception and isinstance(item.exception, Timeout)
    ))

    newly_unreachable = []
    for host in timed_out_hosts:
        if host not in self.reachable_hosts:
            # Already pruned earlier; removing it again would raise ValueError.
            continue
        if not self.check_connectivity(host):
            print(f"Host {host} is unreachable, pruning from reachable hosts list.")
            newly_unreachable.append(host)

    if not newly_unreachable:
        return

    self.unreachable_hosts.extend(newly_unreachable)
    # Build a new list instead of removing in place, so a list object shared
    # with the caller (or with self.host_list) is never mutated.
    pruned = set(newly_unreachable)
    self.reachable_hosts = [h for h in self.reachable_hosts if h not in pruned]

    # Recreate the client so later commands only target reachable hosts.
    if self.password is None:
        self.client = ParallelSSHClient(self.reachable_hosts, user=self.user, pkey=self.pkey, keepalive_seconds=30)
    else:
        self.client = ParallelSSHClient(self.reachable_hosts, user=self.user, password=self.password, keepalive_seconds=30)


def inform_unreachability(self, cmd_output):
    """
    Append "ABORT: Host Unreachable Error" to the output of every pruned host.

    Args:
        cmd_output: dict mapping host -> output string; modified in place.

    Once a host has been pruned the client is rebuilt without it, so later
    commands return no entry for that host. Such hosts get a new entry holding
    only the marker instead of raising KeyError.
    """
    for host in self.unreachable_hosts:
        cmd_output[host] = cmd_output.get(host, "") + "\nABORT: Host Unreachable Error"


def exec(self, cmd, timeout=None ):
Expand All @@ -53,9 +104,9 @@ def exec(self, cmd, timeout=None ):
cmd_output = {}
print(f'cmd = {cmd}')
if timeout is None:
output = self.client.run_command(cmd )
output = self.client.run_command(cmd, stop_on_errors=self.stop_on_errors )
else:
output = self.client.run_command(cmd, read_timeout=timeout )
output = self.client.run_command(cmd, read_timeout=timeout, stop_on_errors=self.stop_on_errors )
for item in output:
print('#----------------------------------------------------------#')
print(f'Host == {item.host} ==')
Expand All @@ -71,8 +122,16 @@ def exec(self, cmd, timeout=None ):
print(line)
cmd_out_str = cmd_out_str + line.replace( '\t', ' ')
cmd_out_str = cmd_out_str + '\n'
if item.exception:
exc_str = str(item.exception).replace('\t', ' ')
print(exc_str)
cmd_out_str += exc_str + '\n'
cmd_output[item.host] = cmd_out_str

if not self.stop_on_errors:
self.prune_unreachable_hosts(output)
self.inform_unreachability(cmd_output)

return cmd_output


Expand All @@ -85,9 +144,9 @@ def exec_cmd_list(self, cmd_list, timeout=None ):
cmd_output = {}
print(cmd_list)
if timeout is None:
output = self.client.run_command( '%s', host_args=cmd_list )
output = self.client.run_command( '%s', host_args=cmd_list, stop_on_errors=self.stop_on_errors )
else:
output = self.client.run_command( '%s', host_args=cmd_list, read_timeout=timeout )
output = self.client.run_command( '%s', host_args=cmd_list, read_timeout=timeout, stop_on_errors=self.stop_on_errors )
i = 0
for item in output:
print('#----------------------------------------------------------#')
Expand All @@ -105,9 +164,17 @@ def exec_cmd_list(self, cmd_list, timeout=None ):
print(line)
cmd_out_str = cmd_out_str + line.replace( '\t', ' ')
cmd_out_str = cmd_out_str + '\n'
if item.exception:
exc_str = str(item.exception).replace('\t', ' ')
print(exc_str)
cmd_out_str += exc_str + '\n'
i=i+1
cmd_output[item.host] = cmd_out_str

if not self.stop_on_errors:
self.prune_unreachable_hosts(output)
self.inform_unreachability(cmd_output)

return cmd_output
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do the test files need to check this cmd_output for the "Host Unreachable" string and take some action on their phdl obj to remove the bad host?

Right now, it's executing the commands on the bad node as well and returning with ERROR.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No action is required by the test scripts; the lib will prune the node itself and report it by adding

cmd_output[host] += "\nABORT: Host Unreachable Error"
in the results.



Expand All @@ -126,7 +193,7 @@ def scp_file(self, local_file, remote_file, recurse=False ):

def reboot_connections(self ):
    """
    Send a forced reboot ('reboot -f') to every host of the current client.

    The command is issued exactly once, with self.stop_on_errors forwarded to
    run_command like the other commands of this class.
    """
    print('Rebooting Connections')
    self.client.run_command( 'reboot -f', stop_on_errors=self.stop_on_errors )

def destroy_clients(self ):
print('Destroying Current phdl connections ..')
Expand Down
Empty file added lib/unittests/__init__.py
Empty file.
Loading