Skip to content

Commit

Permalink
Fix select_device() for Multi-GPU (#6434)
Browse files Browse the repository at this point in the history
* Fix `select_device()` for Multi-GPU

Possible fix for ultralytics/yolov5#6431

* Update torch_utils.py

* Update torch_utils.py

* Update torch_utils.py

* Update torch_utils.py

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update
  • Loading branch information
SecretStar112 committed Jan 26, 2022
1 parent cc16da8 commit 4af08b9
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 5 deletions.
4 changes: 2 additions & 2 deletions utils/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,13 @@
from utils.augmentations import Albumentations, augment_hsv, copy_paste, letterbox, mixup, random_perspective
from utils.general import (LOGGER, NUM_THREADS, check_dataset, check_requirements, check_yaml, clean_str,
segments2boxes, xyn2xy, xywh2xyxy, xywhn2xyxy, xyxy2xywhn)
from utils.torch_utils import torch_distributed_zero_first
from utils.torch_utils import device_count, torch_distributed_zero_first

# Parameters
HELP_URL = 'https://github.com/ultralytics/yolov5/wiki/Train-Custom-Data'
IMG_FORMATS = ['bmp', 'dng', 'jpeg', 'jpg', 'mpo', 'png', 'tif', 'tiff', 'webp'] # include image suffixes
VID_FORMATS = ['asf', 'avi', 'gif', 'm4v', 'mkv', 'mov', 'mp4', 'mpeg', 'mpg', 'wmv'] # include video suffixes
DEVICE_COUNT = max(torch.cuda.device_count(), 1)
DEVICE_COUNT = max(device_count(), 1) # number of CUDA devices

# Get orientation exif tag
for orientation in ExifTags.TAGS.keys():
Expand Down
15 changes: 12 additions & 3 deletions utils/torch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,15 @@ def git_describe(path=Path(__file__).parent): # path must be a directory
return '' # not a git repository


def device_count():
# Returns number of CUDA devices available. Safe version of torch.cuda.device_count().
try:
cmd = 'nvidia-smi -L | wc -l'
return int(subprocess.run(cmd, shell=True, capture_output=True, check=True).stdout.decode().split()[-1])
except Exception as e:
return 0


def select_device(device='', batch_size=0, newline=True):
# device = 'cpu' or '0' or '0,1,2,3'
s = f'YOLOv5 🚀 {git_describe() or date_modified()} torch {torch.__version__} ' # string
Expand All @@ -61,10 +70,10 @@ def select_device(device='', batch_size=0, newline=True):
if cpu:
os.environ['CUDA_VISIBLE_DEVICES'] = '-1' # force torch.cuda.is_available() = False
elif device: # non-cpu device requested
nd = torch.cuda.device_count() # number of CUDA devices
assert torch.cuda.is_available(), 'CUDA is not available, use `--device cpu` or do not pass a --device'
nd = device_count() # number of CUDA devices
assert nd > int(max(device.split(','))), f'Invalid `--device {device}` request, valid devices are 0 - {nd - 1}'
os.environ['CUDA_VISIBLE_DEVICES'] = device # set environment variable (must be after asserts)
os.environ['CUDA_VISIBLE_DEVICES'] = device # set environment variable - must be before assert is_available()
assert torch.cuda.is_available(), 'CUDA is not available, use `--device cpu` or do not pass a --device'

cuda = not cpu and torch.cuda.is_available()
if cuda:
Expand Down

0 comments on commit 4af08b9

Please sign in to comment.