Updated lightning version requirement for FSDP (#182)
jyothisambolu committed May 3, 2024
1 parent 34b57fd commit 7f4e930
Showing 4 changed files with 12 additions and 5 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -14,6 +14,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added support for LightningCLI. ([#173](https://github.com/Lightning-AI/lightning-Habana/pull/173))
- Added experimental support for FSDP on HPU. ([#174](https://github.com/Lightning-AI/lightning-Habana/pull/174))
- Added support for FP8 inference with DeepSpeed. ([#176](https://github.com/Lightning-AI/lightning-Habana/pull/176))
+- Updated the lightning version check for using FSDP. ([#182](https://github.com/Lightning-AI/lightning-Habana/pull/182))


### Changed
6 changes: 6 additions & 0 deletions docs/source/advanced.rst
@@ -546,6 +546,12 @@ Limitations of FSDP on HPU

For more details on the supported FSDP features and functionalities, and limitations refer to `Using Fully Sharded Data Parallel (FSDP) with Intel Gaudi <https://docs.habana.ai/en/latest/PyTorch/PyTorch_FSDP/Pytorch_FSDP.html>`_.

+.. note::
+
+   This is an experimental feature.
+   This feature requires lightning/pytorch-lightning >= 2.3.0, or a nightly build installed from source.
+
+
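For context, a minimal sketch of what the added note covers: running a Trainer with FSDP on HPU. This assumes `HPUAccelerator` and `HPUFSDPStrategy` are importable from the module paths used in this repository; everything other than the `state_dict_type` parameter (visible in the fsdp.py diff below) is illustrative:

    import lightning.pytorch as pl
    from lightning_habana.pytorch.accelerator import HPUAccelerator
    from lightning_habana.pytorch.strategies import HPUFSDPStrategy

    # Needs lightning/pytorch-lightning >= 2.3.0 (or a nightly built from
    # source); otherwise HPUFSDPStrategy raises at construction time.
    trainer = pl.Trainer(
        accelerator=HPUAccelerator(),
        devices=2,
        strategy=HPUFSDPStrategy(state_dict_type="sharded"),  # or "full"
    )
    # trainer.fit(model)  # model: any LightningModule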
----

Using HPU Graphs
6 changes: 3 additions & 3 deletions src/lightning_habana/pytorch/strategies/fsdp.py
@@ -60,7 +60,7 @@
from lightning_habana.pytorch.plugins.io_plugin import HPUCheckpointIO
from lightning_habana.pytorch.strategies.parallel import HPUParallelStrategy, _hpu_broadcast_object_list
from lightning_habana.utils.hpu_distributed import _sync_ddp_if_available
-from lightning_habana.utils.imports import _HABANA_FRAMEWORK_AVAILABLE, _LIGHTNING_LESSER_EQUAL_2_2_3
+from lightning_habana.utils.imports import _HABANA_FRAMEWORK_AVAILABLE, _LIGHTNING_GREATER_EQUAL_2_3_0

if _HABANA_FRAMEWORK_AVAILABLE:
import habana_frameworks.torch.distributed.hccl as hpu_dist
@@ -105,8 +105,8 @@ def __init__(
        state_dict_type: Literal["full", "sharded"] = "full",
        **kwargs: Any,
    ) -> None:
-        if _LIGHTNING_LESSER_EQUAL_2_2_3:
-            raise OSError("HPUFSDPStrategy requires `pytorch-lightning > 2.2.3`.")
+        if not _LIGHTNING_GREATER_EQUAL_2_3_0:
+            raise OSError("HPUFSDPStrategy requires `lightning >= 2.3.0` or `pytorch-lightning >= 2.3.0`.")
        super().__init__(
            accelerator=accelerator,
            parallel_devices=parallel_devices,
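The practical effect of the new guard, sketched as a hypothetical session in an environment that still has an older release (say 2.2.x) installed:

    from lightning_habana.pytorch.strategies import HPUFSDPStrategy

    # _LIGHTNING_GREATER_EQUAL_2_3_0 is False here, so the strategy
    # fails fast instead of hitting missing upstream FSDP APIs later.
    try:
        HPUFSDPStrategy()
    except OSError as err:
        print(err)  # HPUFSDPStrategy requires `lightning >= 2.3.0` or ...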
4 changes: 2 additions & 2 deletions src/lightning_habana/utils/imports.py
@@ -27,8 +27,8 @@
_LIGHTNING_GREATER_EQUAL_2_0_0 = compare_version("lightning", operator.ge, "2.0.0") or compare_version(
    "pytorch_lightning", operator.ge, "2.0.0"
)
-_LIGHTNING_LESSER_EQUAL_2_2_3 = compare_version("lightning", operator.le, "2.2.3") or compare_version(
-    "pytorch_lightning", operator.le, "2.2.3"
+_LIGHTNING_GREATER_EQUAL_2_3_0 = compare_version("lightning", operator.ge, "2.3.0", True) or compare_version(
+    "pytorch_lightning", operator.ge, "2.3.0", True
)
_TORCHVISION_AVAILABLE = RequirementCache("torchvision")
_KINETO_AVAILABLE = torch.profiler.kineto_available()
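A note on the fourth positional argument: in lightning_utilities' `compare_version`, passing `True` sets `use_base_version`, which compares release numbers with pre-release/dev suffixes stripped. That is what lets a nightly built from source pass the gate, matching the docs note above. A small illustration with a hypothetical installed version `2.3.0.dev20240503`:

    import operator
    from lightning_utilities.core.imports import compare_version

    # PEP 440 orders 2.3.0.dev20240503 *before* 2.3.0, so a strict
    # comparison rejects the nightly ...
    compare_version("lightning", operator.ge, "2.3.0")        # -> False
    # ... while use_base_version=True compares 2.3.0 >= 2.3.0:
    compare_version("lightning", operator.ge, "2.3.0", True)  # -> True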
