diff --git a/CHANGELOG.md b/CHANGELOG.md
index 8b572a7945216..4ca0bcc162fb1 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -163,7 +163,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed sampler replacement logic with `overfit_batches` to only replace the sampler when `SequentialSampler` is not used ([#10486](https://github.com/PyTorchLightning/pytorch-lightning/issues/10486))
 
 
--
+- Fixed propagation of device and dtype information to submodules of LightningLite when they inherit from `DeviceDtypeModuleMixin` ([#10559](https://github.com/PyTorchLightning/pytorch-lightning/issues/10559))
 
 
 -
diff --git a/pytorch_lightning/lite/wrappers.py b/pytorch_lightning/lite/wrappers.py
index 615f461055204..ff95e89d1d2cf 100644
--- a/pytorch_lightning/lite/wrappers.py
+++ b/pytorch_lightning/lite/wrappers.py
@@ -24,6 +24,7 @@
 from torch.utils.data import DataLoader
 
 from pytorch_lightning.accelerators import Accelerator
+from pytorch_lightning.core.mixins import DeviceDtypeModuleMixin
 from pytorch_lightning.plugins import PrecisionPlugin
 from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device
 
@@ -64,7 +65,7 @@ def step(self, closure: Optional[Callable] = None) -> None:
         )
 
 
-class _LiteModule(nn.Module):
+class _LiteModule(DeviceDtypeModuleMixin):
     def __init__(self, module: nn.Module, precision_plugin: PrecisionPlugin) -> None:
         """The LiteModule is a thin wrapper around the :class:`torch.nn.Module` and handles precision / autocast
         automatically for the forward pass.
diff --git a/tests/lite/test_wrappers.py b/tests/lite/test_wrappers.py
index 4993a10c8dbc2..c271d3b3163ed 100644
--- a/tests/lite/test_wrappers.py
+++ b/tests/lite/test_wrappers.py
@@ -17,6 +17,7 @@
 import torch
 from torch.utils.data.dataloader import DataLoader
 
+from pytorch_lightning.core.mixins import DeviceDtypeModuleMixin
 from pytorch_lightning.lite import LightningLite
 from pytorch_lightning.lite.wrappers import _LiteDataLoader, _LiteModule, _LiteOptimizer
 from tests.helpers.runif import RunIf
@@ -65,6 +66,27 @@ def check_autocast(forward_input):
     assert out.dtype == input_type or out.dtype == torch.get_default_dtype()
 
 
+@pytest.mark.parametrize(
+    "device", [torch.device("cpu"), pytest.param(torch.device("cuda", 0), marks=RunIf(min_gpus=1))]
+)
+@pytest.mark.parametrize("dtype", [torch.float32, torch.float16])
+def test_lite_module_device_dtype_propagation(device, dtype):
+    """Test that the LiteModule propagates device and dtype properties to its submodules (e.g. torchmetrics)."""
+
+    class DeviceModule(DeviceDtypeModuleMixin):
+        pass
+
+    device_module = DeviceModule()
+    lite_module = _LiteModule(device_module, Mock())
+    lite_module.to(device)
+    assert device_module.device == device
+    assert lite_module.device == device
+
+    lite_module.to(dtype)
+    assert device_module.dtype == dtype
+    assert lite_module.dtype == dtype
+
+
 def test_lite_dataloader_iterator():
     """Test that the iteration over a LiteDataLoader wraps the iterator of the underlying dataloader (no automatic
     device placement)."""
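
For readers skimming the patch, a minimal usage sketch of the behaviour it enables (not part of the diff): after this change, calling `.to(...)` on the `_LiteModule` wrapper also updates the `device`/`dtype` bookkeeping of any submodule that inherits from `DeviceDtypeModuleMixin`, such as torchmetrics metrics. The snippet assumes a PyTorch Lightning release where `_LiteModule` and `DeviceDtypeModuleMixin` live at the import paths touched above; the names `TrackedSubmodule`, `submodule`, and `wrapped` are illustrative only.

```python
# Minimal sketch, not part of the patch. Assumes a pytorch_lightning version
# with LightningLite support (~1.5.x), matching the import paths in this diff.
from unittest.mock import Mock

import torch

from pytorch_lightning.core.mixins import DeviceDtypeModuleMixin
from pytorch_lightning.lite.wrappers import _LiteModule


class TrackedSubmodule(DeviceDtypeModuleMixin):
    """A bare submodule that only tracks .device/.dtype via the mixin."""


submodule = TrackedSubmodule()
# The precision plugin is not exercised here, so a Mock stands in for it,
# mirroring the unit test added above.
wrapped = _LiteModule(submodule, Mock())

# Because _LiteModule now inherits from DeviceDtypeModuleMixin instead of
# plain nn.Module, .to(...) propagates to mixin submodules as well.
wrapped.to(torch.float16)
assert wrapped.dtype == torch.float16
assert submodule.dtype == torch.float16
```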