diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md
index 6aafd01fe9a29..a26b68647aadf 100644
--- a/src/lightning/pytorch/CHANGELOG.md
+++ b/src/lightning/pytorch/CHANGELOG.md
@@ -391,6 +391,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Removed the `lightning.pytorch.strategies.DDPSpawnStrategy` in favor of `DDPStrategy(start_method='spawn')` (merged both classes) ([#16809](https://github.com/Lightning-AI/lightning/pull/16809))
 
+- Removed registration of `ShardedTensor` state dict hooks in `LightningModule.__init__` with `torch>=2.1` ([#16892](https://github.com/Lightning-AI/lightning/pull/16892))
+
+
+
 ### Fixed
 
 - Fixed an issue where `DistributedSampler.set_epoch` wasn't getting called during `trainer.predict` ([#16785](https://github.com/Lightning-AI/lightning/pull/16785), [#16826](https://github.com/Lightning-AI/lightning/pull/16826))
diff --git a/src/lightning/pytorch/core/module.py b/src/lightning/pytorch/core/module.py
index c3d53f5712a44..d2335154d5ddb 100644
--- a/src/lightning/pytorch/core/module.py
+++ b/src/lightning/pytorch/core/module.py
@@ -35,7 +35,7 @@
 from lightning.fabric.utilities.cloud_io import get_filesystem
 from lightning.fabric.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin
 from lightning.fabric.utilities.distributed import _distributed_available, _sync_ddp
-from lightning.fabric.utilities.imports import _IS_WINDOWS, _TORCH_GREATER_EQUAL_2_0
+from lightning.fabric.utilities.imports import _IS_WINDOWS, _TORCH_GREATER_EQUAL_2_0, _TORCH_GREATER_EQUAL_2_1
 from lightning.fabric.wrappers import _FabricOptimizer
 from lightning.pytorch.callbacks.callback import Callback
 from lightning.pytorch.core.hooks import CheckpointHooks, DataHooks, ModelHooks
@@ -1440,6 +1440,9 @@ def _register_sharded_tensor_state_dict_hooks_if_available(self) -> None:
         These hooks ensure that ShardedTensors are included when saving, and are loaded the LightningModule correctly.
 
         """
+        if _TORCH_GREATER_EQUAL_2_1:
+            # ShardedTensor is deprecated in favor of DistributedTensor
+            return
         if _IS_WINDOWS or not torch.distributed.is_available():
             rank_zero_debug("Could not register sharded tensor state dict hooks")
             return
diff --git a/tests/tests_pytorch/core/test_lightning_module.py b/tests/tests_pytorch/core/test_lightning_module.py
index c7ad77f0dcbfb..2b4b2c5cb201e 100644
--- a/tests/tests_pytorch/core/test_lightning_module.py
+++ b/tests/tests_pytorch/core/test_lightning_module.py
@@ -310,7 +310,7 @@ def assert_device(device: torch.device) -> None:
     assert_device(torch.device("cpu"))
 
 
-@RunIf(skip_windows=True)
+@RunIf(skip_windows=True, max_torch="2.1.0")
 def test_sharded_tensor_state_dict(single_process_pg):
     from torch.distributed._shard.sharded_tensor import empty as sharded_tensor_empty
     from torch.distributed._sharding_spec import ChunkShardingSpec
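
For context, the new early return is gated on `_TORCH_GREATER_EQUAL_2_1`, imported from `lightning.fabric.utilities.imports`. Below is a minimal sketch of how such a torch-version flag can be derived; this is an assumption for illustration only, and the real helper in Lightning may be implemented differently (e.g. via a shared compare-version utility).

```python
# Hypothetical sketch of a torch-version flag comparable to _TORCH_GREATER_EQUAL_2_1.
# The actual helper lives in lightning.fabric.utilities.imports and may differ.
from packaging.version import Version

import torch

# Compare against the base release so that pre-release builds such as
# "2.1.0.dev20230310" already count as torch 2.1.
_TORCH_GREATER_EQUAL_2_1 = Version(Version(torch.__version__).base_version) >= Version("2.1.0")

if _TORCH_GREATER_EQUAL_2_1:
    # Mirrors the new guard in _register_sharded_tensor_state_dict_hooks_if_available:
    # skip registering ShardedTensor state dict hooks on torch >= 2.1.
    pass
```

Correspondingly, `@RunIf(skip_windows=True, max_torch="2.1.0")` keeps `test_sharded_tensor_state_dict` running only on torch versions where the hooks are still registered.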