diff --git a/graph_net/torch/test_reference_device.py b/graph_net/torch/test_reference_device.py index bb80c1e8c..d5f1b2891 100644 --- a/graph_net/torch/test_reference_device.py +++ b/graph_net/torch/test_reference_device.py @@ -2,6 +2,7 @@ import os import sys import types +import torch from pathlib import Path from graph_net_bench import path_utils @@ -78,6 +79,7 @@ def main(args): assert args.device in ["cuda", "cpu"] eval_backend_perf.set_seed(args.seed) + torch.set_default_device(args.device) ref_dump_dir = Path(args.reference_dir) ref_dump_dir.mkdir(parents=True, exist_ok=True) diff --git a/graph_net_bench/torch/eval_backend_perf.py b/graph_net_bench/torch/eval_backend_perf.py index 5c8586f30..3fd6db3ff 100644 --- a/graph_net_bench/torch/eval_backend_perf.py +++ b/graph_net_bench/torch/eval_backend_perf.py @@ -195,6 +195,7 @@ def measure_performance(model_call, args, compiler): def eval_single_model_with_single_backend(args): check_and_complete_args(args) set_seed(args.seed) + torch.set_default_device(args.device) os.makedirs(args.output_path, exist_ok=True) log_path = utils.get_log_path(args.output_path, args.model_path) output_dump_path = utils.get_output_path(args.output_path, args.model_path) diff --git a/graph_net_bench/torch/test_compiler.py b/graph_net_bench/torch/test_compiler.py index 8ee670fd2..0923e19d6 100755 --- a/graph_net_bench/torch/test_compiler.py +++ b/graph_net_bench/torch/test_compiler.py @@ -487,6 +487,7 @@ def main(args): initalize_seed = 123 set_seed(random_seed=initalize_seed) + torch.set_default_device(args.device) if path_utils.is_single_model_dir(args.model_path): test_single_model(args)