diff --git a/tests/run_tests_distributed.py b/tests/run_tests_distributed.py
index 8c4af8e1..abf32966 100755
--- a/tests/run_tests_distributed.py
+++ b/tests/run_tests_distributed.py
@@ -44,6 +44,10 @@ def _distributed_worker(rank, world_size, test_file, pytest_args):
     try:
         # Run pytest directly in this process
         exit_code = pytest.main([test_file] + pytest_args)
+        # Exit with pytest's code on failure: the parent only observes this
+        # process's exit status, not the value returned from the worker.
+        if exit_code != 0:
+            sys.exit(exit_code)
         return exit_code
     finally:
         # Restore original argv
@@ -82,7 +86,21 @@ def main():
     print(f"args={args}, test_file={test_file}, pytest_args={pytest_args}")
 
     # Run all tests within a single distributed process group
-    mp.spawn(_distributed_worker, args=(num_ranks, test_file, pytest_args), nprocs=num_ranks, join=True)
+    try:
+        mp.spawn(
+            _distributed_worker,
+            args=(num_ranks, test_file, pytest_args),
+            nprocs=num_ranks,
+            join=True,
+        )
+    except mp.ProcessExitedException as e:
+        # A worker's sys.exit() ends only that child process; mp.spawn reports
+        # it to the parent as ProcessExitedException (never SystemExit), so
+        # forward the worker's exit code. A negative code means the worker was
+        # killed by a signal; report that as a generic failure.
+        sys.exit(e.exit_code if e.exit_code > 0 else 1)
+    # Any other exception from mp.spawn propagates: the interpreter prints the
+    # traceback and exits with status 1, so the failure is not hidden.
 
 
 if __name__ == "__main__":