diff --git a/pytorch_lightning/root_module/memory.py b/pytorch_lightning/root_module/memory.py index d16d2659405cf..53203e7ae6a01 100644 --- a/pytorch_lightning/root_module/memory.py +++ b/pytorch_lightning/root_module/memory.py @@ -3,6 +3,7 @@ ''' import gc +import os import subprocess import numpy as np @@ -198,19 +199,10 @@ def get_memory_profile(mode): memory_map = get_gpu_memory_map() if mode == 'min_max': - min_mem = 1000000 - min_k = None - max_mem = 0 - max_k = None - for k, v in memory_map: - if v > max_mem: - max_mem = v - max_k = k - if v < min_mem: - min_mem = v - min_k = k - - memory_map = {min_k: min_mem, max_k: max_mem} + min_index, min_memory = min(memory_map.items(), key=lambda item: item[1]) + max_index, max_memory = max(memory_map.items(), key=lambda item: item[1]) + + memory_map = {min_index: min_memory, max_index: max_memory} return memory_map @@ -224,17 +216,18 @@ def get_gpu_memory_map(): Keys are device ids as integers. Values are memory usage as integers in MB. """ - result = subprocess.check_output( + result = subprocess.run( [ - 'nvidia-smi', '--query-gpu=memory.used', - '--format=csv,nounits,noheader' - ], encoding='utf-8') + 'nvidia-smi', + '--query-gpu=memory.used', + '--format=csv,nounits,noheader', + ], + encoding='utf-8', + capture_output=True, + check=True) # Convert lines into a dictionary - gpu_memory = [int(x) for x in result.strip().split('\n')] - gpu_memory_map = {} - for k, v in zip(range(len(gpu_memory)), gpu_memory): - k = f'gpu_{k}' - gpu_memory_map[k] = v + gpu_memory = [int(x) for x in result.stdout.strip().split(os.linesep)] + gpu_memory_map = {f'gpu_{index}': memory for index, memory in enumerate(gpu_memory)} return gpu_memory_map diff --git a/tests/test_gpu_models.py b/tests/test_gpu_models.py index 118a3f2a521d3..17f8a8fd41b5d 100644 --- a/tests/test_gpu_models.py +++ b/tests/test_gpu_models.py @@ -224,7 +224,7 @@ def test_multi_gpu_model_dp(): testing_utils.run_gpu_model_test(trainer_options, model, hparams) # test memory helper functions - memory.get_gpu_memory_map() + memory.get_memory_profile('min_max') def test_ddp_sampler_error():