diff --git a/Makefile b/Makefile
index 82ea3cc6b2bfc..f342010a14e87 100644
--- a/Makefile
+++ b/Makefile
@@ -10,7 +10,7 @@ export PACKAGE_NAME=pytorch
 
 # In Lightning Studio, the `lightning` package comes pre-installed.
 # Uninstall it first to ensure the editable install works correctly.
-setup:
+setup: update
 	uv pip uninstall lightning pytorch-lightning lightning-fabric || true
 	uv pip install -r requirements.txt \
 		-r requirements/pytorch/base.txt \
diff --git a/src/lightning/fabric/CHANGELOG.md b/src/lightning/fabric/CHANGELOG.md
index b8c1af644ad4a..74e5fa45a11d1 100644
--- a/src/lightning/fabric/CHANGELOG.md
+++ b/src/lightning/fabric/CHANGELOG.md
@@ -30,6 +30,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Learning rate scheduler is stepped at the end of epoch when `on_train_batch_start` returns -1 ([#21296](https://github.com/Lightning-AI/pytorch-lightning/issues/21296)).
 
+- Fixed FSDP mixed precision semantics and added user warning ([#21361](https://github.com/Lightning-AI/pytorch-lightning/pull/21361))
+
+
 ---
 
 ## [2.5.5] - 2025-09-05
diff --git a/src/lightning/fabric/plugins/precision/fsdp.py b/src/lightning/fabric/plugins/precision/fsdp.py
index 270a67e3a2338..189135e7b19e8 100644
--- a/src/lightning/fabric/plugins/precision/fsdp.py
+++ b/src/lightning/fabric/plugins/precision/fsdp.py
@@ -24,6 +24,7 @@
 from lightning.fabric.plugins.precision.amp import _optimizer_handles_unscaling
 from lightning.fabric.plugins.precision.precision import Precision
 from lightning.fabric.plugins.precision.utils import _convert_fp_tensor, _DtypeContextManager
+from lightning.fabric.utilities import rank_zero_warn
 from lightning.fabric.utilities.types import Optimizable
 
 if TYPE_CHECKING:
@@ -84,19 +85,18 @@ def convert_module(self, module: Module) -> Module:
     def mixed_precision_config(self) -> "TorchMixedPrecision":
         from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision as TorchMixedPrecision
 
-        if self.precision == "16-mixed":
-            param_dtype = torch.float32
-            reduce_dtype = buffer_dtype = torch.float16
-        elif self.precision == "bf16-mixed":
-            param_dtype = torch.float32
-            reduce_dtype = buffer_dtype = torch.bfloat16
-        elif self.precision == "16-true":
+        if self.precision in ("16-true", "bf16-true"):
+            rank_zero_warn(
+                f"FSDP with `{self.precision}` enables computation in lower precision. "
+                "FSDP will always retain a full-precision copy of the model parameters for sharding."
+            )
+
+        if self.precision in ("16-true", "16-mixed"):
             param_dtype = reduce_dtype = buffer_dtype = torch.float16
-        elif self.precision == "bf16-true":
+        elif self.precision in ("bf16-true", "bf16-mixed"):
             param_dtype = reduce_dtype = buffer_dtype = torch.bfloat16
         elif self.precision == "32-true":
-            param_dtype = torch.float32
-            reduce_dtype = buffer_dtype = torch.float32
+            param_dtype = reduce_dtype = buffer_dtype = torch.float32
         else:
             raise ValueError(f"Was unable to infer precision type, received {self.precision!r}.")
diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md
index a409a522973c1..ea350f6899c24 100644
--- a/src/lightning/pytorch/CHANGELOG.md
+++ b/src/lightning/pytorch/CHANGELOG.md
@@ -79,6 +79,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed synchronization of gradients in manual optimization with `DDPStrategy(static_graph=True)` ([#21251](https://github.com/Lightning-AI/pytorch-lightning/pull/21251))
 
+- Fixed FSDP mixed precision semantics and added user warning ([#21361](https://github.com/Lightning-AI/pytorch-lightning/pull/21361))
+
 
 ---
diff --git a/src/lightning/pytorch/plugins/precision/fsdp.py b/src/lightning/pytorch/plugins/precision/fsdp.py
index f3bab3e915e91..337c6a465278d 100644
--- a/src/lightning/pytorch/plugins/precision/fsdp.py
+++ b/src/lightning/pytorch/plugins/precision/fsdp.py
@@ -24,6 +24,7 @@
 from lightning.fabric.plugins.precision.amp import _optimizer_handles_unscaling
 from lightning.fabric.plugins.precision.fsdp import _PRECISION_INPUT
 from lightning.fabric.plugins.precision.utils import _convert_fp_tensor, _DtypeContextManager
+from lightning.fabric.utilities import rank_zero_warn
 from lightning.fabric.utilities.types import Optimizable
 from lightning.pytorch.plugins.precision.precision import Precision
 from lightning.pytorch.utilities.exceptions import MisconfigurationException
@@ -94,19 +95,18 @@ def clip_grad_by_norm(self, *_: Any, **__: Any) -> None:
     def mixed_precision_config(self) -> "TorchMixedPrecision":
         from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision as TorchMixedPrecision
 
-        if self.precision == "16-mixed":
-            param_dtype = torch.float32
-            reduce_dtype = buffer_dtype = torch.float16
-        elif self.precision == "bf16-mixed":
-            param_dtype = torch.float32
-            reduce_dtype = buffer_dtype = torch.bfloat16
-        elif self.precision == "16-true":
+        if self.precision in ("16-true", "bf16-true"):
+            rank_zero_warn(
+                f"FSDP with `{self.precision}` enables computation in lower precision. "
+                "FSDP will always retain a full-precision copy of the model parameters for sharding."
+            )
+
+        if self.precision in ("16-true", "16-mixed"):
             param_dtype = reduce_dtype = buffer_dtype = torch.float16
-        elif self.precision == "bf16-true":
+        elif self.precision in ("bf16-true", "bf16-mixed"):
             param_dtype = reduce_dtype = buffer_dtype = torch.bfloat16
         elif self.precision == "32-true":
-            param_dtype = torch.float32
-            reduce_dtype = buffer_dtype = torch.float32
+            param_dtype = reduce_dtype = buffer_dtype = torch.float32
         else:
             raise MisconfigurationException(f"Was unable to infer precision type, received {self.precision!r}.")
diff --git a/tests/tests_fabric/plugins/precision/test_fsdp.py b/tests/tests_fabric/plugins/precision/test_fsdp.py
index b15e8e6c65f57..7507002ab4fd1 100644
--- a/tests/tests_fabric/plugins/precision/test_fsdp.py
+++ b/tests/tests_fabric/plugins/precision/test_fsdp.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from contextlib import contextmanager
 from unittest.mock import Mock
 
 import pytest
@@ -21,19 +22,30 @@
 from tests_fabric.helpers.runif import RunIf
 
 
+# Pytest passes args/kwargs to the context manager used with `pytest.warns`.
+# `contextlib.nullcontext` doesn't accept them, so this no-op version does.
+@contextmanager
+def null_ctx(*args, **kwargs):
+    yield
+
+
 @pytest.mark.parametrize(
     ("precision", "expected"),
     [
         ("16-true", (torch.float16, torch.float16, torch.float16)),
         ("bf16-true", (torch.bfloat16, torch.bfloat16, torch.bfloat16)),
-        ("16-mixed", (torch.float32, torch.float16, torch.float16)),
-        ("bf16-mixed", (torch.float32, torch.bfloat16, torch.bfloat16)),
+        ("16-mixed", (torch.float16, torch.float16, torch.float16)),
+        ("bf16-mixed", (torch.bfloat16, torch.bfloat16, torch.bfloat16)),
         ("32-true", (torch.float32, torch.float32, torch.float32)),
     ],
 )
 def test_fsdp_precision_config(precision, expected):
     plugin = FSDPPrecision(precision=precision)
-    config = plugin.mixed_precision_config
+
+    warning_ctx = pytest.warns if precision in ("16-true", "bf16-true") else null_ctx
+
+    with warning_ctx(UserWarning, match="enables computation in lower precision"):
+        config = plugin.mixed_precision_config
 
     assert config.param_dtype == expected[0]
     assert config.buffer_dtype == expected[1]
diff --git a/tests/tests_fabric/strategies/test_fsdp_integration.py b/tests/tests_fabric/strategies/test_fsdp_integration.py
index 5da9b50399a94..532f0f9b8ca94 100644
--- a/tests/tests_fabric/strategies/test_fsdp_integration.py
+++ b/tests/tests_fabric/strategies/test_fsdp_integration.py
@@ -87,15 +87,9 @@ def step(self, model, batch):
         precision = self.fabric._precision
         assert isinstance(precision, FSDPPrecision)
 
-        if precision.precision == "16-mixed":
-            param_dtype = torch.float32
-            reduce_dtype = buffer_dtype = torch.float16
-        elif precision.precision == "bf16-mixed":
-            param_dtype = torch.float32
-            reduce_dtype = buffer_dtype = torch.bfloat16
-        elif precision.precision == "16-true":
+        if precision.precision in ("16-true", "16-mixed"):
             param_dtype = reduce_dtype = buffer_dtype = torch.float16
-        elif precision.precision == "bf16-true":
+        elif precision.precision in ("bf16-true", "bf16-mixed"):
             param_dtype = reduce_dtype = buffer_dtype = torch.bfloat16
         else:
             raise ValueError(f"Unknown precision {precision.precision}")
diff --git a/tests/tests_pytorch/plugins/precision/test_fsdp.py b/tests/tests_pytorch/plugins/precision/test_fsdp.py
index f8731aa424b38..0834ef1f98400 100644
--- a/tests/tests_pytorch/plugins/precision/test_fsdp.py
+++ b/tests/tests_pytorch/plugins/precision/test_fsdp.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from contextlib import contextmanager
 from unittest.mock import ANY, MagicMock, Mock
 
 import pytest
@@ -21,19 +22,30 @@
 from tests_pytorch.helpers.runif import RunIf
 
 
+# Pytest passes args/kwargs to the context manager used with `pytest.warns`.
+# `contextlib.nullcontext` doesn't accept them, so this no-op version does.
+@contextmanager
+def null_ctx(*args, **kwargs):
+    yield
+
+
 @pytest.mark.parametrize(
     ("precision", "expected"),
     [
         ("16-true", (torch.float16, torch.float16, torch.float16)),
         ("bf16-true", (torch.bfloat16, torch.bfloat16, torch.bfloat16)),
-        ("16-mixed", (torch.float32, torch.float16, torch.float16)),
-        ("bf16-mixed", (torch.float32, torch.bfloat16, torch.bfloat16)),
+        ("16-mixed", (torch.float16, torch.float16, torch.float16)),
+        ("bf16-mixed", (torch.bfloat16, torch.bfloat16, torch.bfloat16)),
         ("32-true", (torch.float32, torch.float32, torch.float32)),
     ],
 )
 def test_fsdp_precision_config(precision, expected):
     plugin = FSDPPrecision(precision=precision)
-    config = plugin.mixed_precision_config
+
+    warning_ctx = pytest.warns if precision in ("16-true", "bf16-true") else null_ctx
+
+    with warning_ctx(UserWarning, match="enables computation in lower precision"):
+        config = plugin.mixed_precision_config
 
     assert config.param_dtype == expected[0]
     assert config.buffer_dtype == expected[1]
diff --git a/tests/tests_pytorch/strategies/test_fsdp.py b/tests/tests_pytorch/strategies/test_fsdp.py
index f7c15b5930be8..1d7818c61146d 100644
--- a/tests/tests_pytorch/strategies/test_fsdp.py
+++ b/tests/tests_pytorch/strategies/test_fsdp.py
@@ -77,16 +77,12 @@ def _assert_layer_fsdp_instance(self) -> None:
         assert isinstance(self.layer, FullyShardedDataParallel)
         assert isinstance(self.trainer.strategy.precision_plugin, FSDPPrecision)
 
-        if self.trainer.precision == "16-mixed":
-            param_dtype = torch.float32
-            reduce_dtype = buffer_dtype = torch.float16
-        elif self.trainer.precision == "bf16-mixed":
-            param_dtype = torch.float32
-            reduce_dtype = buffer_dtype = torch.bfloat16
-        elif self.trainer.precision == "16-true":
+        if self.trainer.precision in ("16-true", "16-mixed"):
             param_dtype = reduce_dtype = buffer_dtype = torch.float16
-        elif self.trainer.precision == "bf16-true":
+        elif self.trainer.precision in ("bf16-true", "bf16-mixed"):
             param_dtype = reduce_dtype = buffer_dtype = torch.bfloat16
+        elif self.trainer.precision == "32-true":
+            param_dtype = reduce_dtype = buffer_dtype = torch.float32
         else:
             raise ValueError(f"Unknown precision {self.trainer.precision}")
@@ -138,16 +134,12 @@ def _assert_layer_fsdp_instance(self) -> None:
         assert isinstance(self.layer, torch.nn.Sequential)
         assert isinstance(self.trainer.strategy.precision_plugin, FSDPPrecision)
 
-        if self.trainer.precision == "16-mixed":
-            param_dtype = torch.float32
-            reduce_dtype = buffer_dtype = torch.float16
-        elif self.trainer.precision == "bf16-mixed":
-            param_dtype = torch.float32
-            reduce_dtype = buffer_dtype = torch.bfloat16
-        elif self.trainer.precision == "16-true":
+        if self.trainer.precision in ("16-true", "16-mixed"):
             param_dtype = reduce_dtype = buffer_dtype = torch.float16
-        elif self.trainer.precision == "bf16-true":
+        elif self.trainer.precision in ("bf16-true", "bf16-mixed"):
             param_dtype = reduce_dtype = buffer_dtype = torch.bfloat16
+        elif self.trainer.precision == "32-true":
+            param_dtype = reduce_dtype = buffer_dtype = torch.float32
         else:
             raise ValueError(f"Unknown precision {self.trainer.precision}")
@@ -227,7 +219,7 @@ def test_strategy_sync_batchnorm(tmp_path):
         accelerator="gpu",
         devices=2,
         strategy="fsdp",
-        precision="16-mixed",
+        precision="32-true",
         max_epochs=1,
         sync_batchnorm=True,
     )
@@ -267,7 +259,7 @@ def training_step(self, batch, batch_idx):
 
 @pytest.mark.filterwarnings("ignore::FutureWarning")
 @RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True)
-@pytest.mark.parametrize("precision", ["16-mixed", pytest.param("bf16-mixed", marks=RunIf(bf16_cuda=True))])
+@pytest.mark.parametrize("precision", ["32-true", pytest.param("bf16-mixed", marks=RunIf(bf16_cuda=True))])
@pytest.mark.parametrize("state_dict_type", ["sharded", "full"])
 def test_strategy_checkpoint(state_dict_type, precision, tmp_path):
     """Test to ensure that checkpoint is saved correctly when using a single GPU, and all stages can be run."""
@@ -359,7 +351,7 @@ def test_checkpoint_multi_gpus(tmp_path, model, strategy, strategy_cfg):
         accelerator="gpu",
         devices=2,
         strategy=strategy,
-        precision="16-mixed",
+        precision="32-true",
         max_epochs=1,
         limit_train_batches=2,
         limit_val_batches=2,
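For reference, a minimal usage sketch (not part of the diff) of the semantics this change establishes, assuming `FSDPPrecision` is imported from the fabric module touched above:

import torch
from lightning.fabric.plugins.precision.fsdp import FSDPPrecision

# With this change, the "mixed" modes also run FSDP compute in the low-precision dtype:
# param_dtype, reduce_dtype, and buffer_dtype all match the requested half precision.
config = FSDPPrecision(precision="16-mixed").mixed_precision_config
assert config.param_dtype == config.reduce_dtype == config.buffer_dtype == torch.float16

# The "true" modes yield the same dtypes but additionally emit the new UserWarning that
# FSDP still retains a full-precision copy of the model parameters for sharding.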