diff --git a/.automation_scripts/run_pytorch_unit_tests.py b/.automation_scripts/run_pytorch_unit_tests.py index b2276a74327cd..82936017a68bc 100644 --- a/.automation_scripts/run_pytorch_unit_tests.py +++ b/.automation_scripts/run_pytorch_unit_tests.py @@ -327,16 +327,14 @@ def run_selected_tests(workflow_name, test_run_test_path, overall_logs_path_curr return selected_results_dict -def run_test_and_summarize_results() -> Dict[str, Any]: - # parse args - args = parse_args() - pytorch_root_dir = str(args.pytorch_root) - priority_tests = bool(args.priority_tests) - test_config = list[str](args.test_config) - default_list = list[str](args.default_list) - distributed_list = list[str](args.distributed_list) - inductor_list = list[str](args.inductor_list) - skip_rerun = bool(args.skip_rerun) +def run_test_and_summarize_results( + pytorch_root_dir: str, + priority_tests: bool, + test_config: List[str], + default_list: List[str], + distributed_list: List[str], + inductor_list: List[str], + skip_rerun: bool) -> Dict[str, Any]: # copy current environment variables _environ = dict(os.environ) @@ -388,8 +386,8 @@ def run_test_and_summarize_results() -> Dict[str, Any]: CONSOLIDATED_LOG_FILE_PATH = overall_logs_path_current_run + CONSOLIDATED_LOG_FILE_NAME # Check multi gpu availability if distributed tests are enabled - if ("distributed" in args.test_config) or len(args.distributed_list) != 0: - check_num_gpus_for_distributed(); + if ("distributed" in test_config) or len(distributed_list) != 0: + check_num_gpus_for_distributed() # Install test requirements command = "pip3 install -r requirements.txt && pip3 install -r .ci/docker/requirements-ci.txt" @@ -511,7 +509,8 @@ def check_num_gpus_for_distributed(): assert num_gpus_visible > 1, "Number of visible GPUs should be >1 to run distributed unit tests" def main(): - all_tests_results = run_test_and_summarize_results() + args = parse_args() + all_tests_results = run_test_and_summarize_results(args.pytorch_root, args.priority_tests, args.test_config, args.default_list, args.distributed_list, args.inductor_list, args.skip_rerun) pprint(dict(all_tests_results)) if __name__ == "__main__":