diff --git a/CHANGELOG.md b/CHANGELOG.md
index 7cb59d5a6bf82..c88f7835c8a36 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -366,6 +366,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed freeing data iterators in loop `on_run_end` ([#9386](https://github.com/PyTorchLightning/pytorch-lightning/pull/9386))
 
 
+- Fixed DeepSpeed crash for RNNs ([#9489](https://github.com/PyTorchLightning/pytorch-lightning/pull/9489))
+
+
 - Fixed `BasePredictionWriter` not returning the batch_indices in a non-distributed setting ([#9432](https://github.com/PyTorchLightning/pytorch-lightning/pull/9432))
 
 
diff --git a/pytorch_lightning/plugins/training_type/deepspeed.py b/pytorch_lightning/plugins/training_type/deepspeed.py
index 23eefb14ec295..3088efb9fe994 100644
--- a/pytorch_lightning/plugins/training_type/deepspeed.py
+++ b/pytorch_lightning/plugins/training_type/deepspeed.py
@@ -123,6 +123,7 @@ def __init__(
         contiguous_memory_optimization: bool = False,
         synchronize_checkpoint_boundary: bool = False,
         load_full_weights: bool = False,
+        partition_module: bool = True,
     ) -> None:
         """Provides capabilities to run training using the DeepSpeed library, with training optimizations for large
         billion parameter models. `For more information: https://pytorch-
@@ -252,6 +253,12 @@ def __init__(
             load_full_weights: True when loading a single checkpoint file containing the model state dict
                 when using ZeRO Stage 3. This differs from the DeepSpeed checkpoint which contains shards
                 per worker.
+
+            partition_module: When True, partitions the ``LightningModule`` across devices when using ZeRO Stage 3.
+                This is the default behaviour to ensure that the entire module is appropriately initialized
+                for DeepSpeed. When False we do not explicitly convert the model, which is fine if NO layers
+                or ALL layers are defined in ``configure_sharded_model``. This is useful for layers such as
+                ``torch.nn.RNN`` which do internal logic when moving to device.
         """
         if not _DEEPSPEED_AVAILABLE:
             raise MisconfigurationException(
@@ -304,6 +311,7 @@ def __init__(
 
         self.remote_device = remote_device
         self.load_full_weights = load_full_weights
+        self.partition_module = partition_module
 
         # default FP16 parameters.
         self.loss_scale = loss_scale
@@ -374,7 +382,7 @@ def init_deepspeed(self):
         precision = self.lightning_module.trainer.accelerator.precision
         model = LightningDeepSpeedModule(pl_module=self.model, precision=precision)
 
-        if self.zero_stage_3:
+        if self.zero_stage_3 and self.partition_module:
             # Ensure the entire model has been moved to the appropriate device
             dtype = torch.float16 if self.precision in (16, "mixed") else torch.float32
             deepspeed.zero.Init(
diff --git a/tests/plugins/test_deepspeed_plugin.py b/tests/plugins/test_deepspeed_plugin.py
index 86ae03e461850..ee82fe9538dbe 100644
--- a/tests/plugins/test_deepspeed_plugin.py
+++ b/tests/plugins/test_deepspeed_plugin.py
@@ -1,3 +1,4 @@
+import contextlib
 import json
 import os
 from typing import Any, Dict, Optional
@@ -409,13 +410,15 @@ def test_deepspeed_stage_3_save_warning(tmpdir):
     )
     trainer.fit(model)
     checkpoint_path = os.path.join(tmpdir, "model.pt")
-    with pytest.warns(UserWarning) as record:
-        # both ranks need to call save checkpoint
+
+    # both ranks need to call save checkpoint, however only rank 0 needs to check the warning
+    context_manager = (
+        pytest.warns(UserWarning, match="each worker will save a shard of the checkpoint within a directory.")
+        if trainer.is_global_zero
+        else contextlib.suppress()
+    )
+    with context_manager:
         trainer.save_checkpoint(checkpoint_path)
-        if trainer.is_global_zero:
-            assert len(record) == 1
-            match = "each worker will save a shard of the checkpoint within a directory."
-            assert match in str(record[0].message)
 
 
 @RunIf(min_gpus=1, deepspeed=True, special=True)
@@ -735,7 +738,7 @@ def on_train_batch_start(
 
 
 @RunIf(min_gpus=2, deepspeed=True, special=True)
-def test_deepspeed_multigpu_test(tmpdir, deepspeed_config):
+def test_deepspeed_multigpu_test(tmpdir):
     """Test to ensure we can use DeepSpeed with just test using ZeRO Stage 3."""
     model = ModelParallelBoringModel()
     trainer = Trainer(
@@ -744,6 +747,57 @@ def test_deepspeed_multigpu_test(tmpdir, deepspeed_config):
     trainer.test(model)
 
 
+@RunIf(min_gpus=1, deepspeed=True, special=True)
+def test_deepspeed_multigpu_partial_partition_parameters(tmpdir):
+    """Test to ensure that a module that defines a layer inside the ``__init__`` and ``configure_sharded_model``
+    correctly converts all parameters to float16 when ``precision=16`` and runs successfully."""
+
+    class TestModel(ModelParallelBoringModel):
+        def __init__(self):
+            super().__init__()
+            self.layer_2 = torch.nn.Linear(32, 32)
+
+        def configure_sharded_model(self) -> None:
+            self.layer = torch.nn.Linear(32, 2)
+
+        def forward(self, x):
+            x = self.layer_2(x)
+            return self.layer(x)
+
+        def on_train_epoch_start(self) -> None:
+            assert all([x.dtype == torch.float16 for x in self.parameters()])
+
+    model = TestModel()
+    trainer = Trainer(
+        default_root_dir=tmpdir, plugins=[DeepSpeedPlugin(stage=3)], gpus=1, fast_dev_run=True, precision=16
+    )
+    trainer.fit(model)
+
+
+@RunIf(min_gpus=1, deepspeed=True, special=True)
+def test_deepspeed_multigpu_test_rnn(tmpdir):
+    """Test to ensure that turning off explicit partitioning of the entire module for ZeRO Stage 3 works when
+    training with certain layers which will crash with explicit partitioning."""
+
+    class TestModel(BoringModel):
+        def __init__(self):
+            super().__init__()
+            self.rnn = torch.nn.GRU(32, 32)
+
+        def on_train_epoch_start(self) -> None:
+            assert all([x.dtype == torch.float16 for x in self.parameters()])
+
+    model = TestModel()
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        plugins=[DeepSpeedPlugin(stage=3, partition_module=False)],
+        gpus=1,
+        fast_dev_run=True,
+        precision=16,
+    )
+    trainer.fit(model)
+
+
 @RunIf(deepspeed=True)
 @mock.patch("deepspeed.init_distributed", autospec=True)
 @pytest.mark.parametrize("platform", ["Linux", "Windows"])
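
Usage sketch: the snippet below shows how the ``partition_module`` flag added in this patch might be used to train an RNN-based module under ZeRO Stage 3, mirroring ``test_deepspeed_multigpu_test_rnn``. The ``RNNModel`` and ``RandomSequenceDataset`` classes are illustrative placeholders; only ``DeepSpeedPlugin(stage=3, partition_module=False)`` comes from the patch itself.

# Minimal, illustrative sketch; assumes a single GPU with DeepSpeed installed.
import torch
from torch.utils.data import DataLoader, Dataset

from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.plugins import DeepSpeedPlugin


class RandomSequenceDataset(Dataset):
    """Illustrative dataset yielding (seq_len=8, features=32) tensors."""

    def __len__(self):
        return 64

    def __getitem__(self, idx):
        return torch.randn(8, 32)


class RNNModel(LightningModule):
    """``torch.nn.GRU`` does internal work when moved to device, which is why
    explicit ZeRO Stage 3 partitioning is disabled below."""

    def __init__(self):
        super().__init__()
        self.rnn = torch.nn.GRU(32, 32)

    def training_step(self, batch, batch_idx):
        output, _ = self.rnn(batch)
        return output.sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


if __name__ == "__main__":
    # partition_module=False skips the explicit deepspeed.zero.Init conversion of the module.
    trainer = Trainer(
        plugins=[DeepSpeedPlugin(stage=3, partition_module=False)],
        gpus=1,
        precision=16,
        fast_dev_run=True,
    )
    trainer.fit(RNNModel(), DataLoader(RandomSequenceDataset(), batch_size=2))

Leaving ``partition_module`` at its default of ``True`` preserves the previous behaviour; per the docstring above, setting it to ``False`` is only appropriate when either no layers or all layers are defined in ``configure_sharded_model``.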