Skip to content

Commit

Permalink
Prepare for ShardedTensor deprecation (#16892)
Browse files Browse the repository at this point in the history
  • Loading branch information
carmocca committed Mar 6, 2023
1 parent 24c0cd7 commit a00e061
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 2 deletions.
4 changes: 4 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
Expand Up @@ -391,6 +391,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Removed the `lightning.pytorch.strategies.DDPSpawnStrategy` in favor of `DDPStrategy(start_method='spawn')` (merged both classes) ([#16809](https://github.com/Lightning-AI/lightning/pull/16809))


- Removed registration of `ShardedTensor` state dict hooks in `LightningModule.__init__` with `torch>=2.1` ([#16892](https://github.com/Lightning-AI/lightning/pull/16892))



### Fixed

- Fixed an issue where `DistributedSampler.set_epoch` wasn't getting called during `trainer.predict` ([#16785](https://github.com/Lightning-AI/lightning/pull/16785), [#16826](https://github.com/Lightning-AI/lightning/pull/16826))
Expand Down
5 changes: 4 additions & 1 deletion src/lightning/pytorch/core/module.py
Expand Up @@ -35,7 +35,7 @@
from lightning.fabric.utilities.cloud_io import get_filesystem
from lightning.fabric.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin
from lightning.fabric.utilities.distributed import _distributed_available, _sync_ddp
from lightning.fabric.utilities.imports import _IS_WINDOWS, _TORCH_GREATER_EQUAL_2_0
from lightning.fabric.utilities.imports import _IS_WINDOWS, _TORCH_GREATER_EQUAL_2_0, _TORCH_GREATER_EQUAL_2_1
from lightning.fabric.wrappers import _FabricOptimizer
from lightning.pytorch.callbacks.callback import Callback
from lightning.pytorch.core.hooks import CheckpointHooks, DataHooks, ModelHooks
Expand Down Expand Up @@ -1440,6 +1440,9 @@ def _register_sharded_tensor_state_dict_hooks_if_available(self) -> None:
These hooks ensure that ShardedTensors are included when saving, and are loaded into the LightningModule correctly.
"""
if _TORCH_GREATER_EQUAL_2_1:
# ShardedTensor is deprecated in favor of DistributedTensor
return
if _IS_WINDOWS or not torch.distributed.is_available():
rank_zero_debug("Could not register sharded tensor state dict hooks")
return
Expand Down
2 changes: 1 addition & 1 deletion tests/tests_pytorch/core/test_lightning_module.py
Expand Up @@ -310,7 +310,7 @@ def assert_device(device: torch.device) -> None:
assert_device(torch.device("cpu"))


@RunIf(skip_windows=True)
@RunIf(skip_windows=True, max_torch="2.1.0")
def test_sharded_tensor_state_dict(single_process_pg):
from torch.distributed._shard.sharded_tensor import empty as sharded_tensor_empty
from torch.distributed._sharding_spec import ChunkShardingSpec
Expand Down

0 comments on commit a00e061

Please sign in to comment.