diff --git a/caffe2/python/utils.py b/caffe2/python/utils.py index 5e87df8058e01..94e370ffbb7c0 100644 --- a/caffe2/python/utils.py +++ b/caffe2/python/utils.py @@ -237,7 +237,7 @@ def ConvertProtoToBinary(proto_class, filename, out_filename): def GetGPUMemoryUsageStats(): - """Get GPU memory usage stats from CUDAContext. This requires flag + """Get GPU memory usage stats from CUDAContext/HIPContext. This requires flag --caffe2_gpu_memory_tracking to be enabled""" from caffe2.python import workspace, core workspace.RunOperatorOnce( @@ -245,7 +245,7 @@ def GetGPUMemoryUsageStats(): "GetGPUMemoryUsage", [], ["____mem____"], - device_option=core.DeviceOption(caffe2_pb2.CUDA, 0), + device_option=core.DeviceOption(caffe2_pb2.CUDA if workspace.has_gpu_support else caffe2_pb2.HIP, 0), ), ) b = workspace.FetchBlob("____mem____")