diff --git a/test/test_cuda.py b/test/test_cuda.py index 9dd18eb12cfbb..1f6decf765ea6 100644 --- a/test/test_cuda.py +++ b/test/test_cuda.py @@ -442,6 +442,9 @@ def test_out_of_memory_retry(self): IS_JETSON, "oom reporting has issues on jetson igx due to partial nvml support" ) def test_set_per_process_memory_fraction(self): + if torch.version.hip and ('gfx1101' in torch.cuda.get_device_properties(0).gcnArchName): + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats() orig = torch.cuda.get_per_process_memory_fraction(0) try: # test invalid fraction value.