Skip to content

Commit

Permalink
make device property always return a device with index (#4851)
Browse files Browse the repository at this point in the history
* make device property always return a device with index

* pep8

* Update test_dtype_device_mixin.py

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
  • Loading branch information
justusschock and Borda committed Nov 26, 2020
1 parent 204a0a2 commit 742ddd8
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 2 deletions.
8 changes: 7 additions & 1 deletion pytorch_lightning/utilities/device_dtype_mixin.py
Expand Up @@ -37,7 +37,13 @@ def dtype(self, new_dtype: Union[str, torch.dtype]):

@property
def device(self) -> Union[str, torch.device]:
return self._device
device = self._device

# make this more explicit to always include the index
if device.type == 'cuda' and device.index is None:
return torch.device(f'cuda:{torch.cuda.current_device()}')

return device

@device.setter
def device(self, new_device: Union[str, torch.device]):
Expand Down
14 changes: 13 additions & 1 deletion tests/utilities/test_dtype_device_mixin.py
Expand Up @@ -55,7 +55,6 @@ def on_train_batch_start(self, trainer, model, batch, batch_idx, dataloader_idx)
])
@pytest.mark.parametrize(['dst_device'], [
pytest.param(torch.device('cpu')),
pytest.param(torch.device('cuda')),
pytest.param(torch.device('cuda', 0)),
])
@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
Expand Down Expand Up @@ -100,3 +99,16 @@ def test_submodules_multi_gpu_ddp_spawn(tmpdir):
max_steps=1,
)
trainer.fit(model)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
def test_gpu_device_includes_index():
model = TopModule()

# explicitly call without an index to see if the returning device contains an index (it should!)
model.cuda()

device = model.device
assert device.type == 'cuda'
assert device.index is not None
assert device.index == torch.cuda.current_device()

0 comments on commit 742ddd8

Please sign in to comment.