diff --git a/CHANGELOG.md b/CHANGELOG.md index 9b58ee426a39b..8df705bac21a0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -170,7 +170,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Stopped `optimizer_zero_grad` from being called after IPU execution ([#12913](https://github.com/PyTorchLightning/pytorch-lightning/pull/12913)) -- +- Enabled mixed precision in `DDPFullyShardedStrategy` when `precision=16` ([#12965](https://github.com/PyTorchLightning/pytorch-lightning/pull/12965)) ## [1.6.2] - 2022-04-27 diff --git a/pytorch_lightning/strategies/fully_sharded.py b/pytorch_lightning/strategies/fully_sharded.py index b61429264d80a..6a902d3e09a3a 100644 --- a/pytorch_lightning/strategies/fully_sharded.py +++ b/pytorch_lightning/strategies/fully_sharded.py @@ -163,7 +163,7 @@ def wrap_policy(*args, **kwargs): cpu_offload=self.cpu_offload, move_grads_to_cpu=self.move_grads_to_cpu, flatten_parameters=self.flatten_parameters, - mixed_precision=(precision == PrecisionType.MIXED), + mixed_precision=(precision in (PrecisionType.MIXED, PrecisionType.HALF)), reshard_after_forward=self.reshard_after_forward, fp32_reduce_scatter=self.fp32_reduce_scatter, compute_dtype=self.compute_dtype, diff --git a/tests/strategies/test_ddp_fully_sharded_with_full_state_dict.py b/tests/strategies/test_ddp_fully_sharded_with_full_state_dict.py index 473658c5418ff..f780d88ce148b 100644 --- a/tests/strategies/test_ddp_fully_sharded_with_full_state_dict.py +++ b/tests/strategies/test_ddp_fully_sharded_with_full_state_dict.py @@ -90,6 +90,11 @@ def _assert_layer_fsdp_instance(self) -> None: assert self.layer.module[0].reshard_after_forward is True assert self.layer.module[2].reshard_after_forward is True + if isinstance(self.trainer.precision_plugin, FullyShardedNativeMixedPrecisionPlugin): + assert self.layer.mixed_precision + assert self.layer.module[0].mixed_precision + assert self.layer.module[2].mixed_precision + @RunIf(min_gpus=1, skip_windows=True, 
standalone=True, fairscale_fully_sharded=True) def test_fully_sharded_strategy_checkpoint(tmpdir):