2 changes: 1 addition & 1 deletion Makefile
@@ -10,7 +10,7 @@ export PACKAGE_NAME=pytorch
 
 # In Lightning Studio, the `lightning` package comes pre-installed.
 # Uninstall it first to ensure the editable install works correctly.
-setup:
+setup: update
 	uv pip uninstall lightning pytorch-lightning lightning-fabric || true
 	uv pip install -r requirements.txt \
 		-r requirements/pytorch/base.txt \
3 changes: 3 additions & 0 deletions src/lightning/fabric/CHANGELOG.md
@@ -30,6 +30,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Learning rate scheduler is stepped at the end of epoch when `on_train_batch_start` returns -1 ([#21296](https://github.com/Lightning-AI/pytorch-lightning/issues/21296)).
 
 
+- Fixed FSDP mixed precision semantics and added user warning ([#21361](https://github.com/Lightning-AI/pytorch-lightning/pull/21361))
+
+
 ---
 
 ## [2.5.5] - 2025-09-05
20 changes: 10 additions & 10 deletions src/lightning/fabric/plugins/precision/fsdp.py
@@ -24,6 +24,7 @@
 from lightning.fabric.plugins.precision.amp import _optimizer_handles_unscaling
 from lightning.fabric.plugins.precision.precision import Precision
 from lightning.fabric.plugins.precision.utils import _convert_fp_tensor, _DtypeContextManager
+from lightning.fabric.utilities import rank_zero_warn
 from lightning.fabric.utilities.types import Optimizable
 
 if TYPE_CHECKING:
@@ -84,19 +85,18 @@ def convert_module(self, module: Module) -> Module:
     def mixed_precision_config(self) -> "TorchMixedPrecision":
         from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision as TorchMixedPrecision
 
-        if self.precision == "16-mixed":
-            param_dtype = torch.float32
-            reduce_dtype = buffer_dtype = torch.float16
-        elif self.precision == "bf16-mixed":
-            param_dtype = torch.float32
-            reduce_dtype = buffer_dtype = torch.bfloat16
-        elif self.precision == "16-true":
+        if self.precision in ("16-true", "bf16-true"):
+            rank_zero_warn(
+                f"FSDP with `{self.precision}` enables computation in lower precision. "
+                "FSDP will always retain a full-precision copy of the model parameters for sharding."
+            )
+
+        if self.precision in ("16-true", "16-mixed"):
             param_dtype = reduce_dtype = buffer_dtype = torch.float16
-        elif self.precision == "bf16-true":
+        elif self.precision in ("bf16-true", "bf16-mixed"):
             param_dtype = reduce_dtype = buffer_dtype = torch.bfloat16
         elif self.precision == "32-true":
-            param_dtype = torch.float32
-            reduce_dtype = buffer_dtype = torch.float32
+            param_dtype = reduce_dtype = buffer_dtype = torch.float32
         else:
             raise ValueError(f"Was unable to infer precision type, received {self.precision!r}.")
 
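For orientation, here is a minimal sketch, not part of the diff, of what the new Fabric semantics look like from user code. It constructs `FSDPPrecision` directly (import path taken from the file edited above) and checks the dtypes plus the new warning; running it on a CPU-only install and the warning-capture approach are assumptions, while the expected dtypes come from the updated parametrized tests further down.

import warnings

import torch
from lightning.fabric.plugins.precision.fsdp import FSDPPrecision

# "16-mixed" now casts parameters, gradient reductions and buffers to float16.
config = FSDPPrecision(precision="16-mixed").mixed_precision_config
assert config.param_dtype == config.reduce_dtype == config.buffer_dtype == torch.float16

# "bf16-true" uses bfloat16 everywhere and additionally emits the new UserWarning,
# because FSDP still keeps a full-precision copy of the parameters for sharding.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    _ = FSDPPrecision(precision="bf16-true").mixed_precision_config
assert any("enables computation in lower precision" in str(w.message) for w in caught)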
2 changes: 2 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
@@ -79,6 +79,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed synchronization of gradients in manual optimization with `DDPStrategy(static_graph=True)` ([#21251](https://github.com/Lightning-AI/pytorch-lightning/pull/21251))
 
 
+- Fixed FSDP mixed precision semantics and added user warning ([#21361](https://github.com/Lightning-AI/pytorch-lightning/pull/21361))
+
 
 ---
 
20 changes: 10 additions & 10 deletions src/lightning/pytorch/plugins/precision/fsdp.py
@@ -24,6 +24,7 @@
 from lightning.fabric.plugins.precision.amp import _optimizer_handles_unscaling
 from lightning.fabric.plugins.precision.fsdp import _PRECISION_INPUT
 from lightning.fabric.plugins.precision.utils import _convert_fp_tensor, _DtypeContextManager
+from lightning.fabric.utilities import rank_zero_warn
 from lightning.fabric.utilities.types import Optimizable
 from lightning.pytorch.plugins.precision.precision import Precision
 from lightning.pytorch.utilities.exceptions import MisconfigurationException
@@ -94,19 +95,18 @@ def clip_grad_by_norm(self, *_: Any, **__: Any) -> None:
     def mixed_precision_config(self) -> "TorchMixedPrecision":
         from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision as TorchMixedPrecision
 
-        if self.precision == "16-mixed":
-            param_dtype = torch.float32
-            reduce_dtype = buffer_dtype = torch.float16
-        elif self.precision == "bf16-mixed":
-            param_dtype = torch.float32
-            reduce_dtype = buffer_dtype = torch.bfloat16
-        elif self.precision == "16-true":
+        if self.precision in ("16-true", "bf16-true"):
+            rank_zero_warn(
+                f"FSDP with `{self.precision}` enables computation in lower precision. "
+                "FSDP will always retain a full-precision copy of the model parameters for sharding."
+            )
+
+        if self.precision in ("16-true", "16-mixed"):
             param_dtype = reduce_dtype = buffer_dtype = torch.float16
-        elif self.precision == "bf16-true":
+        elif self.precision in ("bf16-true", "bf16-mixed"):
             param_dtype = reduce_dtype = buffer_dtype = torch.bfloat16
         elif self.precision == "32-true":
-            param_dtype = torch.float32
-            reduce_dtype = buffer_dtype = torch.float32
+            param_dtype = reduce_dtype = buffer_dtype = torch.float32
         else:
             raise MisconfigurationException(f"Was unable to infer precision type, received {self.precision!r}.")
 
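On the Trainer side the plugin change mirrors the Fabric one above, and users reach it through the `precision` flag combined with `strategy="fsdp"`. A rough usage sketch under stated assumptions: a machine with two CUDA GPUs, and the `BoringModel` demo model standing in for the test models used in the strategy tests further down.

from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel

# Selecting the FSDP strategy plus a precision string routes through the
# FSDPPrecision.mixed_precision_config property shown in the diff above.
trainer = Trainer(
    accelerator="gpu",
    devices=2,
    strategy="fsdp",
    precision="16-mixed",  # params, reductions and buffers now run in float16
    max_epochs=1,
)
trainer.fit(BoringModel())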
18 changes: 15 additions & 3 deletions tests/tests_fabric/plugins/precision/test_fsdp.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from contextlib import contextmanager
 from unittest.mock import Mock
 
 import pytest
@@ -21,19 +22,30 @@
 from tests_fabric.helpers.runif import RunIf
 
 
+# Pytest passes args/kwargs to the context manager used with `pytest.warns`.
+# `contextlib.nullcontext` doesn't accept them, so this no-op version does.
+@contextmanager
+def null_ctx(*args, **kwargs):
+    yield
+
+
 @pytest.mark.parametrize(
     ("precision", "expected"),
     [
         ("16-true", (torch.float16, torch.float16, torch.float16)),
         ("bf16-true", (torch.bfloat16, torch.bfloat16, torch.bfloat16)),
-        ("16-mixed", (torch.float32, torch.float16, torch.float16)),
-        ("bf16-mixed", (torch.float32, torch.bfloat16, torch.bfloat16)),
+        ("16-mixed", (torch.float16, torch.float16, torch.float16)),
+        ("bf16-mixed", (torch.bfloat16, torch.bfloat16, torch.bfloat16)),
         ("32-true", (torch.float32, torch.float32, torch.float32)),
     ],
 )
 def test_fsdp_precision_config(precision, expected):
     plugin = FSDPPrecision(precision=precision)
-    config = plugin.mixed_precision_config
 
+    warning_ctx = pytest.warns if precision in ("16-true", "bf16-true") else null_ctx
+
+    with warning_ctx(UserWarning, match="enables computation in lower precision"):
+        config = plugin.mixed_precision_config
+
     assert config.param_dtype == expected[0]
     assert config.buffer_dtype == expected[1]
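Why both test files define their own no-op context manager instead of reusing `contextlib.nullcontext`: the stand-in has to accept the same call signature as `pytest.warns`, which `nullcontext` does not. A small illustration, not part of the diff:

from contextlib import contextmanager, nullcontext

# `nullcontext` only takes an optional `enter_result`, so calling it the way
# `pytest.warns` is called in the test raises a TypeError.
try:
    nullcontext(UserWarning, match="enables computation in lower precision")
except TypeError:
    pass  # unexpected keyword argument 'match'

# The no-op replacement simply ignores whatever it is given.
@contextmanager
def null_ctx(*args, **kwargs):
    yield

with null_ctx(UserWarning, match="enables computation in lower precision"):
    pass  # body runs; nothing is asserted about warnings here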
10 changes: 2 additions & 8 deletions tests/tests_fabric/strategies/test_fsdp_integration.py
@@ -87,15 +87,9 @@ def step(self, model, batch):
 
         precision = self.fabric._precision
         assert isinstance(precision, FSDPPrecision)
-        if precision.precision == "16-mixed":
-            param_dtype = torch.float32
-            reduce_dtype = buffer_dtype = torch.float16
-        elif precision.precision == "bf16-mixed":
-            param_dtype = torch.float32
-            reduce_dtype = buffer_dtype = torch.bfloat16
-        elif precision.precision == "16-true":
+        if precision.precision in ("16-true", "16-mixed"):
             param_dtype = reduce_dtype = buffer_dtype = torch.float16
-        elif precision.precision == "bf16-true":
+        elif precision.precision in ("bf16-true", "bf16-mixed"):
             param_dtype = reduce_dtype = buffer_dtype = torch.bfloat16
         else:
             raise ValueError(f"Unknown precision {precision.precision}")
18 changes: 15 additions & 3 deletions tests/tests_pytorch/plugins/precision/test_fsdp.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from contextlib import contextmanager
 from unittest.mock import ANY, MagicMock, Mock
 
 import pytest
@@ -21,19 +22,30 @@
 from tests_pytorch.helpers.runif import RunIf
 
 
+# Pytest passes args/kwargs to the context manager used with `pytest.warns`.
+# `contextlib.nullcontext` doesn't accept them, so this no-op version does.
+@contextmanager
+def null_ctx(*args, **kwargs):
+    yield
+
+
 @pytest.mark.parametrize(
     ("precision", "expected"),
     [
         ("16-true", (torch.float16, torch.float16, torch.float16)),
         ("bf16-true", (torch.bfloat16, torch.bfloat16, torch.bfloat16)),
-        ("16-mixed", (torch.float32, torch.float16, torch.float16)),
-        ("bf16-mixed", (torch.float32, torch.bfloat16, torch.bfloat16)),
+        ("16-mixed", (torch.float16, torch.float16, torch.float16)),
+        ("bf16-mixed", (torch.bfloat16, torch.bfloat16, torch.bfloat16)),
         ("32-true", (torch.float32, torch.float32, torch.float32)),
     ],
 )
 def test_fsdp_precision_config(precision, expected):
     plugin = FSDPPrecision(precision=precision)
-    config = plugin.mixed_precision_config
 
+    warning_ctx = pytest.warns if precision in ("16-true", "bf16-true") else null_ctx
+
+    with warning_ctx(UserWarning, match="enables computation in lower precision"):
+        config = plugin.mixed_precision_config
+
     assert config.param_dtype == expected[0]
     assert config.buffer_dtype == expected[1]
30 changes: 11 additions & 19 deletions tests/tests_pytorch/strategies/test_fsdp.py
@@ -77,16 +77,12 @@ def _assert_layer_fsdp_instance(self) -> None:
         assert isinstance(self.layer, FullyShardedDataParallel)
         assert isinstance(self.trainer.strategy.precision_plugin, FSDPPrecision)
 
-        if self.trainer.precision == "16-mixed":
-            param_dtype = torch.float32
-            reduce_dtype = buffer_dtype = torch.float16
-        elif self.trainer.precision == "bf16-mixed":
-            param_dtype = torch.float32
-            reduce_dtype = buffer_dtype = torch.bfloat16
-        elif self.trainer.precision == "16-true":
+        if self.trainer.precision in ("16-true", "16-mixed"):
             param_dtype = reduce_dtype = buffer_dtype = torch.float16
-        elif self.trainer.precision == "bf16-true":
+        elif self.trainer.precision in ("bf16-true", "bf16-mixed"):
             param_dtype = reduce_dtype = buffer_dtype = torch.bfloat16
+        elif self.trainer.precision == "32-true":
+            param_dtype = reduce_dtype = buffer_dtype = torch.float32
         else:
             raise ValueError(f"Unknown precision {self.trainer.precision}")
 
@@ -138,16 +134,12 @@ def _assert_layer_fsdp_instance(self) -> None:
         assert isinstance(self.layer, torch.nn.Sequential)
         assert isinstance(self.trainer.strategy.precision_plugin, FSDPPrecision)
 
-        if self.trainer.precision == "16-mixed":
-            param_dtype = torch.float32
-            reduce_dtype = buffer_dtype = torch.float16
-        elif self.trainer.precision == "bf16-mixed":
-            param_dtype = torch.float32
-            reduce_dtype = buffer_dtype = torch.bfloat16
-        elif self.trainer.precision == "16-true":
+        if self.trainer.precision in ("16-true", "16-mixed"):
             param_dtype = reduce_dtype = buffer_dtype = torch.float16
-        elif self.trainer.precision == "bf16-true":
+        elif self.trainer.precision in ("bf16-true", "bf16-mixed"):
             param_dtype = reduce_dtype = buffer_dtype = torch.bfloat16
+        elif self.trainer.precision == "32-true":
+            param_dtype = reduce_dtype = buffer_dtype = torch.float32
         else:
             raise ValueError(f"Unknown precision {self.trainer.precision}")
 
@@ -227,7 +219,7 @@ def test_strategy_sync_batchnorm(tmp_path):
         accelerator="gpu",
         devices=2,
         strategy="fsdp",
-        precision="16-mixed",
+        precision="32-true",
         max_epochs=1,
         sync_batchnorm=True,
     )
@@ -267,7 +259,7 @@ def training_step(self, batch, batch_idx):
 
 @pytest.mark.filterwarnings("ignore::FutureWarning")
 @RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True)
-@pytest.mark.parametrize("precision", ["16-mixed", pytest.param("bf16-mixed", marks=RunIf(bf16_cuda=True))])
+@pytest.mark.parametrize("precision", ["32-true", pytest.param("bf16-mixed", marks=RunIf(bf16_cuda=True))])
 @pytest.mark.parametrize("state_dict_type", ["sharded", "full"])
 def test_strategy_checkpoint(state_dict_type, precision, tmp_path):
     """Test to ensure that checkpoint is saved correctly when using a single GPU, and all stages can be run."""
@@ -359,7 +351,7 @@ def test_checkpoint_multi_gpus(tmp_path, model, strategy, strategy_cfg):
         accelerator="gpu",
         devices=2,
         strategy=strategy,
-        precision="16-mixed",
+        precision="32-true",
         max_epochs=1,
         limit_train_batches=2,
         limit_val_batches=2,